Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a5c1c7f6
Unverified
Commit
a5c1c7f6
authored
Feb 10, 2019
by
Paul Fultz II
Committed by
GitHub
Feb 10, 2019
Browse files
Merge branch 'develop' into mem_color_ordering_fix
parents
462a4920
d516b099
Changes
303
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
798 additions
and
177 deletions
+798
-177
src/include/migraphx/erase.hpp
src/include/migraphx/erase.hpp
+4
-4
src/include/migraphx/errors.hpp
src/include/migraphx/errors.hpp
+5
-5
src/include/migraphx/fallthrough.hpp
src/include/migraphx/fallthrough.hpp
+6
-6
src/include/migraphx/float_equal.hpp
src/include/migraphx/float_equal.hpp
+6
-6
src/include/migraphx/functional.hpp
src/include/migraphx/functional.hpp
+10
-4
src/include/migraphx/fwd_conv_batchnorm_rewrite.hpp
src/include/migraphx/fwd_conv_batchnorm_rewrite.hpp
+4
-4
src/include/migraphx/generate.hpp
src/include/migraphx/generate.hpp
+7
-7
src/include/migraphx/half.hpp
src/include/migraphx/half.hpp
+4
-4
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+10
-5
src/include/migraphx/instruction_ref.hpp
src/include/migraphx/instruction_ref.hpp
+4
-4
src/include/migraphx/iterator_for.hpp
src/include/migraphx/iterator_for.hpp
+6
-6
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+4
-4
src/include/migraphx/make_shared_array.hpp
src/include/migraphx/make_shared_array.hpp
+5
-5
src/include/migraphx/manage_ptr.hpp
src/include/migraphx/manage_ptr.hpp
+5
-5
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+12
-13
src/include/migraphx/memory_coloring.hpp
src/include/migraphx/memory_coloring.hpp
+4
-4
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+22
-4
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+189
-15
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+440
-72
src/include/migraphx/par_dfor.hpp
src/include/migraphx/par_dfor.hpp
+51
-0
No files found.
src/include/migraphx/erase.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_ERASE_HPP
#define MIGRAPH_GUARD_ERASE_HPP
#ifndef MIGRAPH
X
_GUARD_ERASE_HPP
#define MIGRAPH
X
_GUARD_ERASE_HPP
#include <algorithm>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
/**
* @brief Erase all elements from a container
...
...
@@ -33,7 +33,7 @@ auto erase_if(R&& r, P&& pred)
return
r
.
erase
(
std
::
remove_if
(
r
.
begin
(),
r
.
end
(),
pred
),
r
.
end
());
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/errors.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_ERRORS_HPP
#define MIGRAPH_GUARD_ERRORS_HPP
#ifndef MIGRAPH
X
_GUARD_ERRORS_HPP
#define MIGRAPH
X
_GUARD_ERRORS_HPP
#include <exception>
#include <stdexcept>
...
...
@@ -7,7 +7,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
/// Represents exceptions that can be thrown by migraphxlib
struct
exception
:
std
::
runtime_error
...
...
@@ -43,10 +43,10 @@ inline std::string make_source_context(const std::string& file, int line)
/**
* @brief Throw an exception with context information
*/
#define MIGRAPH_THROW(...) \
#define MIGRAPH
X
_THROW(...) \
throw migraphx::make_exception(migraphx::make_source_context(__FILE__, __LINE__), __VA_ARGS__)
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/fallthrough.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_FALLTHROUGH_HPP
#define MIGRAPH_GUARD_FALLTHROUGH_HPP
#ifndef MIGRAPH
X
_GUARD_FALLTHROUGH_HPP
#define MIGRAPH
X
_GUARD_FALLTHROUGH_HPP
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
#ifdef __clang__
#define MIGRAPH_FALLTHROUGH [[clang::fallthrough]]
#define MIGRAPH
X
_FALLTHROUGH [[clang::fallthrough]]
#else
#define MIGRAPH_FALLTHROUGH
#define MIGRAPH
X
_FALLTHROUGH
#endif
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/float_equal.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
#include <algorithm>
#include <cmath>
...
...
@@ -12,14 +12,14 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
...
Ts
>
using
common_type
=
typename
std
::
common_type
<
Ts
...
>::
type
;
struct
float_equal_fn
{
template
<
class
T
,
MIGRAPH_REQUIRES
(
std
::
is_floating_point
<
T
>{})
>
template
<
class
T
,
MIGRAPH
X
_REQUIRES
(
std
::
is_floating_point
<
T
>{})
>
static
bool
apply
(
T
x
,
T
y
)
{
return
std
::
isfinite
(
x
)
and
std
::
isfinite
(
y
)
and
...
...
@@ -27,7 +27,7 @@ struct float_equal_fn
std
::
nextafter
(
x
,
std
::
numeric_limits
<
T
>::
max
())
>=
y
;
}
template
<
class
T
,
MIGRAPH_REQUIRES
(
not
std
::
is_floating_point
<
T
>{})
>
template
<
class
T
,
MIGRAPH
X
_REQUIRES
(
not
std
::
is_floating_point
<
T
>{})
>
static
bool
apply
(
T
x
,
T
y
)
{
return
x
==
y
;
...
...
@@ -42,7 +42,7 @@ struct float_equal_fn
static
constexpr
float_equal_fn
float_equal
{};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/functional.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP
#define MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_FUNCTIONAL_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_FUNCTIONAL_HPP
#include <utility>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
swallow
{
...
...
@@ -94,6 +94,12 @@ constexpr void each_args(F)
{
}
template
<
class
F
,
class
T
>
auto
unpack
(
F
f
,
T
&
x
)
{
return
sequence_c
<
std
::
tuple_size
<
T
>
{}
>
([
&
](
auto
...
is
)
{
f
(
std
::
get
<
is
>
(
x
)...);
});
}
/// Implements a fix-point combinator
template
<
class
R
,
class
F
>
detail
::
fix_f
<
R
,
F
>
fix
(
F
f
)
...
...
@@ -131,7 +137,7 @@ auto fold(F f)
return
[
=
](
auto
&&
...
xs
)
{
return
fold_impl
(
f
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
};
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/fwd_conv_batchnorm_rewrite.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
program
;
...
...
@@ -19,7 +19,7 @@ struct fwd_conv_batchnorm_rewrite
void
apply
(
program
&
p
)
const
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/generate.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_GENERATE_HPP
#include <migraphx/argument.hpp>
#include <migraphx/literal.hpp>
...
...
@@ -8,9 +8,9 @@
#include <random>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
T
,
MIGRAPH_REQUIRES
(
is_floating_point
<
T
>{})
>
template
<
class
T
,
MIGRAPH
X
_REQUIRES
(
is_floating_point
<
T
>{})
>
constexpr
T
normalize
(
unsigned
long
z
)
{
if
(
z
==
0
)
...
...
@@ -22,7 +22,7 @@ constexpr T normalize(unsigned long z)
return
T
(
result
);
}
template
<
class
T
,
MIGRAPH_REQUIRES
(
is_signed
<
T
>{}
and
not
is_floating_point
<
T
>
{})
>
template
<
class
T
,
MIGRAPH
X
_REQUIRES
(
is_signed
<
T
>{}
and
not
is_floating_point
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
...
...
@@ -30,7 +30,7 @@ constexpr T normalize(unsigned long z)
return
half_max
-
(
z
%
max
);
}
template
<
class
T
,
MIGRAPH_REQUIRES
(
not
is_signed
<
T
>{}
and
std
::
is_integral
<
T
>
{})
>
template
<
class
T
,
MIGRAPH
X
_REQUIRES
(
not
is_signed
<
T
>{}
and
std
::
is_integral
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
...
...
@@ -93,7 +93,7 @@ literal generate_literal(shape s, unsigned long seed = 0);
literal
abs
(
literal
l
);
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/half.hpp
View file @
a5c1c7f6
...
...
@@ -5,14 +5,14 @@
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_HALF_HPP
#define MIGRAPH_GUARD_RTGLIB_HALF_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_HALF_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_HALF_HPP
#include <half.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
using
half
=
half_float
::
half
;
...
...
@@ -33,7 +33,7 @@ struct deduce<half_float::detail::expr>
template
<
class
T
>
using
deduce
=
typename
detail
::
deduce
<
T
>::
type
;
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/instruction.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#include <migraphx/literal.hpp>
#include <migraphx/shape.hpp>
...
...
@@ -11,9 +11,10 @@
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
shape
compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
);
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
instruction_ref
>&
args
);
struct
instruction
{
...
...
@@ -71,7 +72,11 @@ struct instruction
static
void
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
);
static
instruction_ref
get_output_alias
(
instruction_ref
ins
);
argument
eval
()
const
;
void
finalize
(
context
&
ctx
);
static
instruction_ref
get_output_alias
(
instruction_ref
ins
,
bool
shallow
=
false
);
private:
// internal
...
...
@@ -90,7 +95,7 @@ struct instruction
std
::
vector
<
instruction_ref
>
arguments
;
literal
lit
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
namespace
std
{
...
...
src/include/migraphx/instruction_ref.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#define MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#ifndef MIGRAPH
X
_GUARD_INSTRUCTION_REF_HPP
#define MIGRAPH
X
_GUARD_INSTRUCTION_REF_HPP
#include <list>
#include <functional>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
instruction
;
using
instruction_ref
=
std
::
list
<
instruction
>::
iterator
;
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/iterator_for.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#define MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_ITERATOR_FOR_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_ITERATOR_FOR_HPP
#include <cassert>
#include <type_traits>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
T
>
struct
iterator_for_range
...
...
@@ -17,9 +17,9 @@ struct iterator_for_range
struct
iterator
{
base_iterator
i
;
base_iterator
operator
*
()
{
return
i
;
}
base_iterator
operator
*
()
const
{
return
i
;
}
base_iterator
operator
++
()
{
return
++
i
;
}
bool
operator
!=
(
const
iterator
&
rhs
)
{
return
i
!=
rhs
.
i
;
}
bool
operator
!=
(
const
iterator
&
rhs
)
const
{
return
i
!=
rhs
.
i
;
}
};
iterator
begin
()
...
...
@@ -39,7 +39,7 @@ iterator_for_range<T> iterator_for(T& x)
return
{
&
x
};
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/literal.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_LITERAL_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_LITERAL_HPP
#include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp>
...
...
@@ -12,7 +12,7 @@
#include <memory>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
/**
* @brief Represents a raw literal
...
...
@@ -124,7 +124,7 @@ literal transform(literal l1, literal l2, F f)
return
result
;
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/make_shared_array.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_MAKE_SHARED_ARRAY_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_MAKE_SHARED_ARRAY_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_MAKE_SHARED_ARRAY_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_MAKE_SHARED_ARRAY_HPP
#include <memory>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
typename
T
>
std
::
shared_ptr
<
T
>
make_shared_array
(
size_t
size
)
{
return
std
::
shared_ptr
<
T
>
(
new
T
[
size
],
std
::
default_delete
<
T
[]
>
());
return
std
::
shared_ptr
<
T
>
(
new
T
[
size
],
std
::
default_delete
<
T
[]
>
());
// NOLINT
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/manage_ptr.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPH_MANAGE_PTR_HPP
#define MIGRAPH_GUARD_MIGRAPH_MANAGE_PTR_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPH
X
_MANAGE_PTR_HPP
#define MIGRAPH
X
_GUARD_MIGRAPH
X
_MANAGE_PTR_HPP
#include <memory>
#include <type_traits>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
F
,
F
f
>
// NOLINT
struct
manage_deleter
...
...
@@ -51,10 +51,10 @@ shared<T> share(T p)
return
shared
<
T
>
{
std
::
move
(
p
)};
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#define MIGRAPH_MANAGE_PTR(T, F) \
#define MIGRAPH
X
_MANAGE_PTR(T, F) \
migraphx::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
#endif
src/include/migraphx/matcher.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_MATCHER_HPP
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
...
...
@@ -10,7 +10,7 @@
#include <unordered_map>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
namespace
match
{
...
...
@@ -169,7 +169,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
}
/// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPH_BASIC_MATCHER(name, ...)
\
#define MIGRAPH
X
_BASIC_MATCHER(name, ...) \
struct name##_m \
{ \
instruction_ref match(__VA_ARGS__) const; \
...
...
@@ -178,7 +178,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
inline instruction_ref name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPH_PRED_MATCHER(name, ...)
\
#define MIGRAPH
X
_PRED_MATCHER(name, ...) \
struct name##_m \
{ \
bool operator()(__VA_ARGS__) const; \
...
...
@@ -214,7 +214,6 @@ void find_matches(program& p, Ms&&... ms)
bool
match
=
false
;
each_args
(
[
&
](
auto
&&
m
)
{
// cppcheck-suppress knownConditionTrueFalse
if
(
match
)
return
;
auto
r
=
match_instruction
(
p
,
ins
,
m
.
matcher
());
...
...
@@ -266,22 +265,22 @@ auto any_of(Ts... ms)
});
}
MIGRAPH_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
MIGRAPH_PRED_MATCHER
(
none
,
instruction_ref
)
{
return
false
;
}
MIGRAPH_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
MIGRAPH_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
MIGRAPH
X
_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
MIGRAPH
X
_PRED_MATCHER
(
none
,
instruction_ref
)
{
return
false
;
}
MIGRAPH
X
_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
MIGRAPH
X
_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
broadcasted
();
}
MIGRAPH_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
MIGRAPH
X
_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
return
ins
->
outputs
().
front
();
return
ctx
.
not_found
();
}
MIGRAPH_BASIC_MATCHER
(
used_once
,
matcher_context
&
ctx
,
instruction_ref
ins
)
MIGRAPH
X
_BASIC_MATCHER
(
used_once
,
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
return
ins
;
...
...
@@ -340,7 +339,7 @@ inline auto either_arg(std::size_t i, std::size_t j)
}
}
// namespace match
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/memory_coloring.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_HPP
#define MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_MEMORY_COLORING_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_MEMORY_COLORING_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
program
;
/**
...
...
@@ -20,7 +20,7 @@ struct memory_coloring
void
apply
(
program
&
p
)
const
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/onnx.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_ONNX_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_ONNX_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_ONNX_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_ONNX_HPP
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
unknown
{
std
::
string
op
;
std
::
string
name
()
const
{
return
"unknown:"
+
op
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
if
(
input
.
empty
())
return
{};
else
return
input
.
front
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
return
os
;
}
};
/// Create a program from an onnx file
program
parse_onnx
(
const
std
::
string
&
name
);
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operation.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_OPERAND_HPP
#include <cassert>
#include <string>
...
...
@@ -7,16 +7,16 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
context
;
#ifdef DOXYGEN
...
...
@@ -26,6 +26,8 @@ struct operation
{
/// A unique name identifying the operation
std
::
string
name
()
const
;
/// An optional method that can be used to finalize the operator before running
void
finalize
(
context
&
ctx
);
/// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an
/// exception.
...
...
@@ -53,6 +55,11 @@ struct operation
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
};
/// Returns true if operation does not require a context to run compute
bool
is_context_free
(
const
operation
&
x
);
/// Returns true if the operation has a finalize method
bool
has_finalize
(
const
operation
&
x
);
#else
namespace
operation_stream
{
...
...
@@ -89,7 +96,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
}
// namespace operation_equal
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
...
...
@@ -99,18 +106,72 @@ auto compute_op(rank<1>,
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
std
::
string
name
=
x
.
name
();
MIGRAPH_THROW
(
"Not computable: "
+
name
);
MIGRAPH
X
_THROW
(
"Not computable: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
return
compute_op
(
rank
<
2
>
{},
x
,
ctx
,
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
std
::
declval
<
context
&>
()),
output_shape
,
input
))
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable without a context: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_op
(
rank
<
2
>
{},
x
,
output_shape
,
input
);
}
template
<
class
T
>
auto
is_context_free_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
),
std
::
true_type
{});
template
<
class
T
>
auto
is_context_free_op
(
rank
<
0
>
,
const
T
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
->
std
::
false_type
;
template
<
class
T
>
auto
is_context_free_op
(
const
T
&
x
)
->
decltype
(
is_context_free_op
(
rank
<
1
>
{},
x
,
std
::
declval
<
const
shape
&>
(),
std
::
declval
<
std
::
vector
<
argument
>>
()))
{
return
{};
}
template
<
class
T
>
...
...
@@ -132,15 +193,57 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return
output_alias_op
(
rank
<
1
>
{},
x
,
shapes
);
}
template
<
class
T
>
auto
finalize_op
(
rank
<
1
>
,
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
->
decltype
(
x
.
finalize
(
auto_any_cast
(
ctx
),
output_shape
,
input
),
void
())
{
x
.
finalize
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
template
<
class
T
>
void
finalize_op
(
rank
<
0
>
,
T
&
,
context
&
,
const
shape
&
,
const
std
::
vector
<
shape
>&
)
{
}
template
<
class
T
>
void
finalize_op
(
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
{
finalize_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
}
template
<
class
T
>
auto
has_finalize_op
(
rank
<
1
>
,
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
->
decltype
(
x
.
finalize
(
auto_any_cast
(
ctx
),
output_shape
,
input
),
std
::
true_type
{});
template
<
class
T
>
auto
has_finalize_op
(
rank
<
0
>
,
T
&
,
context
&
,
const
shape
&
,
const
std
::
vector
<
shape
>&
)
->
std
::
false_type
;
template
<
class
T
>
auto
has_finalize_op
(
const
T
&
)
->
decltype
(
has_finalize_op
(
rank
<
1
>
{},
std
::
declval
<
T
&>
(),
std
::
declval
<
context
&>
(),
std
::
declval
<
const
shape
&>
(),
std
::
declval
<
std
::
vector
<
shape
>>
()))
{
return
{};
}
/*
* Type-erased interface for:
*
* struct operation
* {
* std::string name() const;
* bool is_context_free() const;
* bool has_finalize() const;
* int output_alias(const std::vector<shape>& input) const;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* };
...
...
@@ -210,12 +313,30 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
bool
is_context_free
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
is_context_free
();
}
bool
has_finalize
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
has_finalize
();
}
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
output_alias
(
input
);
}
void
finalize
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
finalize
(
ctx
,
output
,
input
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
...
@@ -228,6 +349,12 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
ctx
,
output
,
input
);
}
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
output
,
input
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
assert
(
op
.
private_detail_te_handle_mem_var
);
...
...
@@ -240,6 +367,12 @@ struct operation
return
x
.
private_detail_te_get_handle
().
operator
==
(
y
);
}
friend
bool
is_shared
(
const
operation
&
private_detail_x
,
const
operation
&
private_detail_y
)
{
return
private_detail_x
.
private_detail_te_handle_mem_var
==
private_detail_y
.
private_detail_te_handle_mem_var
;
}
private:
struct
private_detail_te_handle_base_type
{
...
...
@@ -247,13 +380,18 @@ struct operation
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
bool
is_context_free
()
const
=
0
;
virtual
bool
has_finalize
()
const
=
0
;
virtual
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
void
finalize
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
bool
operator
==
(
const
operation
&
y
)
const
=
0
;
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
bool
operator
==
(
const
operation
&
y
)
const
=
0
;
};
template
<
typename
PrivateDetailTypeErasedT
>
...
...
@@ -286,12 +424,26 @@ struct operation
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
bool
is_context_free
()
const
override
{
return
is_context_free_op
(
private_detail_te_value
);
}
bool
has_finalize
()
const
override
{
return
has_finalize_op
(
private_detail_te_value
);
}
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
override
{
return
output_alias_op
(
private_detail_te_value
,
input
);
}
void
finalize
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
override
{
finalize_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
override
{
...
...
@@ -306,6 +458,12 @@ struct operation
return
compute_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
}
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
override
{
return
compute_op
(
private_detail_te_value
,
output
,
input
);
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
{
using
migraphx
::
operation_stream
::
operator
<<
;
...
...
@@ -385,9 +543,25 @@ inline const ValueType& any_cast(const operation& x)
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
is_context_free
(
const
operation
&
op
)
{
return
op
.
is_context_free
();
}
template
<
class
T
>
bool
is_context_free
(
const
T
&
x
)
{
return
is_context_free_op
(
x
);
}
inline
bool
has_finalize
(
const
operation
&
op
)
{
return
op
.
has_finalize
();
}
template
<
class
T
>
bool
has_finalize
(
const
T
&
x
)
{
return
has_finalize_op
(
x
);
}
#endif
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_OPERATORS_HPP
#define MIGRAPH_GUARD_OPERATORS_HPP
#ifndef MIGRAPH
X
_GUARD_OPERATORS_HPP
#define MIGRAPH
X
_GUARD_OPERATORS_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
namespace
op
{
enum
padding_mode_t
{
default_
,
// NOLINT
same
,
valid
};
struct
not_computable
{
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH
X
_THROW
(
"not computable"
);
}
};
...
...
@@ -51,18 +60,38 @@ struct batch_norm_inference
}
};
struct
lrn
{
float
alpha
=
0.0001
;
float
beta
=
0.75
;
float
bias
=
1.0
;
int
size
=
1
;
std
::
string
name
()
const
{
return
"lrn"
;
}
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
alpha
,
"alpha"
),
f
(
self
.
beta
,
"beta"
),
f
(
self
.
bias
,
"bias"
),
f
(
self
.
size
,
"size"
));
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
front
();
}
};
struct
convolution
{
std
::
array
<
std
::
size_t
,
2
>
padding
=
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{{
1
,
1
}};
enum
padding_mode_t
{
default_
,
// NOLINT
same
,
valid
};
padding_mode_t
padding_mode
=
default_
;
int
group
=
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -70,7 +99,8 @@ struct convolution
return
pack
(
f
(
self
.
padding
,
"padding"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
dilation
,
"dilation"
),
f
(
self
.
padding_mode
,
"padding_mode"
));
f
(
self
.
padding_mode
,
"padding_mode"
),
f
(
self
.
group
,
"group"
));
}
std
::
string
name
()
const
{
return
"convolution"
;
}
...
...
@@ -124,7 +154,7 @@ struct convolution
}
else
{
MIGRAPH_THROW
(
"Invalid padding mode"
);
MIGRAPH
X
_THROW
(
"Invalid padding mode"
);
}
}
};
...
...
@@ -134,12 +164,7 @@ struct im2col
std
::
array
<
std
::
size_t
,
2
>
padding
=
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{{
1
,
1
}};
enum
padding_mode_t
{
default_
,
// NOLINT
same
,
valid
};
padding_mode_t
padding_mode
=
default_
;
template
<
class
Self
,
class
F
>
...
...
@@ -163,7 +188,7 @@ struct im2col
auto
kernel_width
=
weights
.
lens
()[
3
];
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
if
(
batch_size
!=
1
)
MIGRAPH_THROW
(
"im2col only support batch_size 1"
);
MIGRAPH
X
_THROW
(
"im2col only support batch_size 1"
);
auto
output_height
=
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
2
]
-
(
1
+
dilation
[
0
]
*
(
kernel_height
-
1
))
+
2
*
padding
[
0
])
/
...
...
@@ -185,12 +210,14 @@ struct pooling
std
::
array
<
std
::
size_t
,
2
>
padding
=
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
lengths
=
{{
1
,
1
}};
padding_mode_t
padding_mode
=
default_
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
mode
,
"mode"
),
f
(
self
.
padding
,
"padding"
),
f
(
self
.
padding
,
"padding_mode"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
lengths
,
"lengths"
));
}
...
...
@@ -207,7 +234,10 @@ struct pooling
assert
(
lengths
[
0
]
<=
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]));
assert
(
lengths
[
1
]
<=
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]));
return
{
t
,
if
(
padding_mode
==
default_
)
{
return
{
t
,
{
input
.
lens
()[
0
],
input
.
lens
()[
1
],
...
...
@@ -222,6 +252,39 @@ struct pooling
static_cast
<
float
>
(
stride
[
1
])))
+
1
)),
}};
}
else
if
(
padding_mode
==
same
)
{
return
{
t
,
{
input
.
lens
()[
0
],
input
.
lens
()[
1
],
static_cast
<
std
::
size_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
input
.
lens
()[
2
])
/
stride
[
0
])),
static_cast
<
std
::
size_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
input
.
lens
()[
3
])
/
stride
[
1
]))}};
}
else
if
(
padding_mode
==
valid
)
{
return
{
t
,
{
input
.
lens
()[
0
],
input
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
std
::
ptrdiff_t
(
std
::
floor
((
input
.
lens
()[
2
]
-
lengths
[
0
])
/
static_cast
<
float
>
(
stride
[
0
])))
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
std
::
ptrdiff_t
(
std
::
floor
((
input
.
lens
()[
3
]
-
lengths
[
1
])
/
static_cast
<
float
>
(
stride
[
1
])))
+
1
)),
}};
}
else
{
MIGRAPHX_THROW
(
"Invalid padding mode"
);
}
}
};
...
...
@@ -234,10 +297,28 @@ struct leaky_relu
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
front
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
leaky_relu
&
op
)
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
os
<<
op
.
name
()
<<
":"
<<
op
.
alpha
;
return
os
;
return
pack
(
f
(
self
.
alpha
,
"alpha"
));
}
};
struct
elu
{
std
::
string
name
()
const
{
return
"elu"
;
}
float
alpha
;
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
front
();
}
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
alpha
,
"alpha"
));
}
};
...
...
@@ -261,30 +342,36 @@ struct transpose
auto
t
=
input
.
type
();
if
(
dims
.
size
()
!=
input_lens
.
size
())
{
MIGRAPH_THROW
(
"Permutation has wrong number of axes"
);
MIGRAPH
X
_THROW
(
"Permutation has wrong number of axes"
);
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
!
std
::
is_permutation
(
axes
.
begin
(),
axes
.
end
(),
dims
.
begin
()))
{
MIGRAPH_THROW
(
"Invalid permutation"
);
MIGRAPH
X
_THROW
(
"Invalid permutation"
);
}
std
::
vector
<
size_t
>
output_lens
(
input_lens
.
size
());
std
::
vector
<
size_t
>
output_strides
(
input_lens
.
size
());
for
(
in
t
i
=
0
;
i
<
output_lens
.
size
();
i
++
)
for
(
std
::
size_
t
i
=
0
;
i
<
output_lens
.
size
();
i
++
)
{
output_lens
[
i
]
=
input_lens
[
dims
[
i
]];
output_strides
[
i
]
=
input_strides
[
dims
[
i
]];
}
return
{
t
,
output_lens
,
output_strides
};
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
/// The contiguous operator takes a non-standard input tensor and returns
/// the same tensor but in standard form. For example, if input tensor A which has lens = (4,5)
/// is first transposed, i.e. lens = (5,4), this tensor's data layout remained the same
/// during the transpose operation; only it's shape lengths and strides were changed.
/// This leaves the tensor in a non-standard form. The contiguous operator copies the
/// underlying data such that resulting tensor is returned to a standard form.
struct
contiguous
{
std
::
string
name
()
const
{
return
"contiguous"
;
}
...
...
@@ -295,6 +382,17 @@ struct contiguous
auto
t
=
inputs
.
at
(
0
).
type
();
return
{
t
,
lens
};
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
assert
(
output_shape
.
standard
());
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
idx
.
begin
(),
idx
.
end
());
});
});
return
result
;
}
};
struct
concat
...
...
@@ -302,7 +400,7 @@ struct concat
std
::
size_t
axis
=
0
;
std
::
string
name
()
const
{
return
"concat"
;
}
std
::
vector
<
std
::
size_t
>
compute_offsets
(
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
args
)
const
const
std
::
vector
<
argument
>
&
args
)
const
{
std
::
vector
<
std
::
size_t
>
offsets
;
std
::
vector
<
std
::
size_t
>
offset
(
args
[
0
].
get_shape
().
lens
().
size
(),
0
);
...
...
@@ -318,7 +416,7 @@ struct concat
{
if
(
inputs
.
empty
())
{
MIGRAPH_THROW
(
"Number of input tensors should exceed 0"
);
MIGRAPH
X
_THROW
(
"Number of input tensors should exceed 0"
);
}
const
auto
&
first_shape_lens
=
inputs
.
front
().
lens
();
...
...
@@ -331,7 +429,7 @@ struct concat
return
s
.
lens
()[
l
]
==
first_shape_lens
[
l
];
}))
{
MIGRAPH_THROW
(
"Non-axis dimensions should match"
);
MIGRAPH
X
_THROW
(
"Non-axis dimensions should match"
);
}
}
}
...
...
@@ -346,7 +444,27 @@ struct concat
new_lens
[
axis
]
=
new_dim_axis
;
return
{
type
,
new_lens
};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
std
::
vector
<
std
::
size_t
>
coffsets
=
compute_offsets
(
output_shape
,
args
);
for
(
std
::
size_t
l
=
0
;
l
<
args
.
size
();
l
++
)
{
auto
argl
=
args
[
l
];
std
::
size_t
nelements
=
argl
.
get_shape
().
elements
();
visit_all
(
result
,
argl
)([
&
](
auto
output
,
auto
input
)
{
auto
slice_shape
=
shape
{
output_shape
.
type
(),
input
.
get_shape
().
lens
(),
output_shape
.
strides
()};
auto
slice
=
make_view
(
slice_shape
,
output
.
data
()
+
coffsets
[
l
]);
// cppcheck-suppress useStlAlgorithm
for
(
std
::
size_t
i
=
0
;
i
<
nelements
;
i
++
)
{
slice
[
i
]
=
input
[
i
];
}
});
}
return
result
;
}
};
struct
slice
...
...
@@ -400,18 +518,9 @@ struct slice
auto
t
=
input_shape
.
type
();
const
auto
&
old_lens
=
input_shape
.
lens
();
const
auto
&
old_strides
=
input_shape
.
strides
();
// std::vector<int64_t> t_axes(old_lens.size());
// if(axes.size() == 0)
// {
// std::iota(t_axes.begin(), t_axes.end(), 0);
// }
// else
// {
// std::copy(axes.begin(), axes.end(), t_axes.begin());
// }
if
(
starts
.
size
()
!=
axes
.
size
()
||
axes
.
size
()
!=
ends
.
size
())
{
MIGRAPH_THROW
(
"inconsistent sizes"
);
MIGRAPH
X
_THROW
(
"inconsistent sizes"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
=
old_lens
;
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
...
...
@@ -422,7 +531,7 @@ struct slice
}
return
shape
{
t
,
new_lens
,
old_strides
};
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
input
=
args
[
0
];
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
output_shape
.
type_size
();
...
...
@@ -450,7 +559,7 @@ struct squeeze
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
input_shape
.
lens
()[
axis
]
!=
1
;
}))
{
MIGRAPH_THROW
(
"squeeze axis dimension should be equal to 1"
);
MIGRAPH
X
_THROW
(
"squeeze axis dimension should be equal to 1"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
;
if
(
axes
.
empty
())
...
...
@@ -472,7 +581,7 @@ struct squeeze
}
return
shape
{
type
,
new_lens
};
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
...
...
@@ -511,7 +620,7 @@ struct unsqueeze
}
return
shape
{
type
,
new_lens
};
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
...
...
@@ -536,16 +645,21 @@ struct reshape
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
if
(
n_neg_dims
>
1
)
MIGRAPH_THROW
(
"Dimensions for reshape can only have one -1 dim"
);
MIGRAPH
X
_THROW
(
"Dimensions for reshape can only have one -1 dim"
);
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
{
if
(
dims
[
i
]
==
0
)
rdims
[
i
]
=
idims
[
i
];
// since rdims using size_t type, -1 is the max value
// is size_t that cause later compuation incorrect
if
(
dims
[
i
]
==
-
1
)
rdims
[
i
]
=
1
;
}
if
(
n_neg_dims
>
0
)
{
size_t
missing_dim
=
-
inputs
.
front
().
elements
()
/
inputs
.
front
().
elements
()
/
std
::
accumulate
(
rdims
.
begin
(),
rdims
.
end
(),
1
,
std
::
multiplies
<
int64_t
>
());
for
(
std
::
size_t
i
=
0
;
i
<
rdims
.
size
();
i
++
)
{
...
...
@@ -553,23 +667,140 @@ struct reshape
rdims
[
i
]
=
missing_dim
;
}
}
if
(
dims
.
back
()
==
-
1
)
shape
s
{
inputs
.
front
().
type
(),
rdims
};
if
(
s
.
elements
()
!=
inputs
.
front
().
elements
())
MIGRAPHX_THROW
(
"Wrong number of elements for reshape"
);
return
s
;
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
struct
pad
{
std
::
vector
<
int64_t
>
pads
;
float
value
=
0.0
f
;
enum
pad_op_mode_t
{
constant_pad
,
reflect_pad
,
edge_pad
};
pad_op_mode_t
mode
=
constant_pad
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
mode
,
"mode"
),
f
(
self
.
pads
,
"pads"
),
f
(
self
.
value
,
"value"
));
}
std
::
string
name
()
const
{
return
"pad"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
idims
.
begin
(),
idims
.
end
());
std
::
size_t
num_dims
=
rdims
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
num_dims
;
i
++
)
{
rdims
.
pop_back
();
std
::
copy
(
idims
.
begin
()
+
rdims
.
size
(),
idims
.
end
(),
std
::
back_inserter
(
rdims
));
rdims
[
i
]
+=
pads
[
i
]
+
pads
[
i
+
num_dims
];
}
shape
s
{
inputs
.
front
().
type
(),
rdims
};
if
(
s
.
elements
()
!=
inputs
.
front
().
elements
())
MIGRAPH_THROW
(
"Wrong number of elements for reshape"
);
return
s
;
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
};
struct
as_shape
{
shape
s
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
s
,
"shape"
));
}
std
::
string
name
()
const
{
return
"as_shape"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
assert
(
inputs
.
front
().
elements
()
==
s
.
elements
());
return
s
;
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
struct
gather
{
int
axis
=
0
;
std
::
string
name
()
const
{
return
"gather"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
lens
=
inputs
[
0
].
lens
();
int
n_dim
=
static_cast
<
int
>
(
lens
.
size
());
if
(
axis
>=
n_dim
||
axis
<
-
n_dim
)
{
MIGRAPHX_THROW
(
"Gather: axis is out of range."
);
}
// negative axis means counting dimensions from back
int
axis_index
=
(
axis
<
0
)
?
(
n_dim
+
axis
)
:
axis
;
auto
type
=
inputs
[
0
].
type
();
lens
[
axis_index
]
=
inputs
[
1
].
elements
();
return
{
type
,
lens
};
}
template
<
class
T
>
void
compute_index
(
const
T
&
out_idx
,
const
int
axis_index
,
const
std
::
vector
<
std
::
size_t
>&
vec_indices
,
const
std
::
size_t
max_dim
,
T
&
in_idx
)
const
{
in_idx
=
out_idx
;
std
::
size_t
idx
=
vec_indices
.
at
(
out_idx
[
axis_index
]);
if
(
idx
>=
max_dim
)
{
MIGRAPHX_THROW
(
"Gather: indices are out of range in input tensor"
);
}
in_idx
[
axis_index
]
=
idx
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
// negative axis means counting dimensions from back
int
axis_index
=
(
axis
<
0
)
?
(
output_shape
.
lens
().
size
()
+
axis
)
:
axis
;
// max dimension in axis
std
::
size_t
max_dim
=
args
[
0
].
get_shape
().
lens
()[
axis_index
];
std
::
vector
<
std
::
size_t
>
vec_indices
;
args
[
1
].
visit
([
&
](
auto
indices
)
{
vec_indices
.
assign
(
indices
.
begin
(),
indices
.
end
());
});
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
std
::
vector
<
std
::
size_t
>
in_idx
;
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
this
->
compute_index
(
idx
,
axis_index
,
vec_indices
,
max_dim
,
in_idx
);
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
in_idx
.
begin
(),
in_idx
.
end
());
});
});
return
result
;
}
};
struct
dot
{
float
alpha
=
1.0
;
...
...
@@ -590,8 +821,8 @@ struct dot
auto
t
=
a
.
type
();
if
(
a
.
lens
()[
1
]
!=
b
.
lens
()[
0
])
MIGRAPH_THROW
(
"Inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
MIGRAPH
X
_THROW
(
"Inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
return
{
t
,
{
a
.
lens
()[
0
],
b
.
lens
()[
1
]}};
}
};
...
...
@@ -609,7 +840,7 @@ struct identity
{
std
::
string
name
()
const
{
return
"identity"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
at
(
0
);
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
...
...
@@ -626,6 +857,11 @@ struct exp : unary
std
::
string
name
()
const
{
return
"exp"
;
}
};
struct
log
:
unary
{
std
::
string
name
()
const
{
return
"log"
;
}
};
struct
sin
:
unary
{
std
::
string
name
()
const
{
return
"sin"
;
}
...
...
@@ -656,6 +892,16 @@ struct atan : unary
std
::
string
name
()
const
{
return
"atan"
;
}
};
struct
sinh
:
unary
{
std
::
string
name
()
const
{
return
"sinh"
;
}
};
struct
cosh
:
unary
{
std
::
string
name
()
const
{
return
"cosh"
;
}
};
struct
tanh
:
unary
{
std
::
string
name
()
const
{
return
"tanh"
;
}
...
...
@@ -704,7 +950,7 @@ struct flatten
if
(
axis
>
lens
.
size
())
{
MIGRAPH_THROW
(
"axis for flatten must be less than tensor rank"
);
MIGRAPH
X
_THROW
(
"axis for flatten must be less than tensor rank"
);
}
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
...
...
@@ -712,12 +958,21 @@ struct flatten
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are
/// computed from multi-indicies by computing the inner product on the multi-index with the strides.
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is
/// obvious from there that we can negate the effects of a given axis by setting the stride of that
/// axis to zero.
struct
broadcast
{
uint64_t
axis
=
0
;
...
...
@@ -742,7 +997,7 @@ struct broadcast
}))
{
if
(
axis
!=
0
)
MIGRAPH_THROW
(
"when broadcasting tensor of size 1, axis should be 0"
);
MIGRAPH
X
_THROW
(
"when broadcasting tensor of size 1, axis should be 0"
);
return
{
t
,
broadcast_shape
.
lens
(),
std
::
move
(
bcast_strides
)};
}
else
...
...
@@ -750,12 +1005,12 @@ struct broadcast
assert
(
broadcast_shape
.
lens
().
size
()
-
axis
>=
input
.
lens
().
size
());
if
(
!
std
::
equal
(
input
.
lens
().
begin
(),
input
.
lens
().
end
(),
broadcast_shape
.
lens
().
begin
()
+
axis
))
MIGRAPH_THROW
(
"when broadcasting success sizes must match"
);
MIGRAPH
X
_THROW
(
"when broadcasting success sizes must match"
);
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
return
{
t
,
broadcast_shape
.
lens
(),
std
::
move
(
bcast_strides
)};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
...
...
@@ -781,10 +1036,10 @@ struct multibroadcast
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
empty
())
MIGRAPH_THROW
(
"inputs dimensions should be > 0"
);
MIGRAPH
X
_THROW
(
"inputs dimensions should be > 0"
);
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
MIGRAPH_THROW
(
"inputs dimensions should <= output size"
);
MIGRAPH
X
_THROW
(
"inputs dimensions should <= output size"
);
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
...
...
@@ -797,7 +1052,7 @@ struct multibroadcast
}
return
{
t
,
output_lens
,
bcast_strides
};
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
...
...
@@ -813,13 +1068,12 @@ struct scalar
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
assert
(
check_shapes
{
inputs
}.
has
(
1
).
only_dims
(
1
).
size
()
==
1
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input
=
inputs
.
at
(
0
);
auto
t
=
inputs
.
at
(
0
).
type
();
std
::
vector
<
std
::
size_t
>
strides
(
scalar_bcast
.
lens
().
size
(),
0
);
return
{
t
,
scalar_bcast
.
lens
(),
strides
};
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
...
...
@@ -857,6 +1111,16 @@ struct div : binary
std
::
string
name
()
const
{
return
"div"
;
}
};
struct
max
:
binary
{
std
::
string
name
()
const
{
return
"max"
;
}
};
struct
min
:
binary
{
std
::
string
name
()
const
{
return
"min"
;
}
};
struct
load
{
shape
s
;
...
...
@@ -874,7 +1138,7 @@ struct load
check_shapes
{
inputs
}.
has
(
1
);
return
s
;
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
{
s
,
args
[
0
].
data
()
+
offset
};
}
...
...
@@ -897,14 +1161,118 @@ struct outline
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
return
s
;
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
return
{
s
,
nullptr
};
}
};
// indicate rnn computation direction
enum
class
rnn_direction
{
forward
,
reverse
,
bidirectional
,
};
struct
rnn
{
std
::
size_t
hidden_size
=
1
;
std
::
vector
<
operation
>
actv_funcs
{
tanh
{},
tanh
{}};
rnn_direction
direction
=
rnn_direction
::
forward
;
float
clip
=
0.0
f
;
std
::
string
name
()
const
{
return
"rnn"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
{
s
,
nullptr
};
auto
in_dims
=
inputs
[
0
].
lens
();
auto
hidden_dims
=
inputs
[
2
].
lens
();
if
(
hidden_size
!=
hidden_dims
[
2
])
{
MIGRAPHX_THROW
(
"RNN: hidden size mismatch in attribute and input"
);
}
std
::
size_t
num_directions
=
1
;
if
(
direction
==
rnn_direction
::
bidirectional
)
{
num_directions
=
2
;
}
if
(
num_directions
!=
hidden_dims
[
0
])
{
MIGRAPHX_THROW
(
"RNN: num_direction mismatch in attribute and input"
);
}
std
::
vector
<
std
::
size_t
>
out_dims
(
in_dims
);
out_dims
.
insert
(
out_dims
.
begin
()
+
1
,
num_directions
);
out_dims
.
back
()
=
hidden_size
;
return
{
inputs
[
0
].
type
(),
out_dims
};
}
};
struct
rnn_last_output
{
std
::
string
name
()
const
{
return
"rnn_last_output"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
dims
=
inputs
[
0
].
lens
();
// remove the first dimension, remaing are output shape
dims
.
erase
(
dims
.
begin
());
return
{
inputs
[
0
].
type
(),
dims
};
}
};
struct
gru
{
std
::
size_t
hidden_size
=
1
;
std
::
vector
<
operation
>
actv_funcs
{
sigmoid
{},
tanh
{}};
rnn_direction
direction
=
rnn_direction
::
forward
;
float
clip
=
0.0
f
;
int
linear_before_reset
=
0
;
std
::
string
name
()
const
{
return
"gru"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
in_dims
=
inputs
[
0
].
lens
();
auto
hidden_dims
=
inputs
[
2
].
lens
();
if
(
hidden_size
!=
hidden_dims
[
2
])
{
MIGRAPHX_THROW
(
"GRU: hidden size mismatch in attribute and input"
);
}
std
::
size_t
num_directions
=
1
;
if
(
direction
==
rnn_direction
::
bidirectional
)
{
num_directions
=
2
;
}
if
(
num_directions
!=
hidden_dims
[
0
])
{
MIGRAPHX_THROW
(
"GRU: num_direction does not match the direction attribute"
);
}
std
::
vector
<
std
::
size_t
>
out_dims
(
in_dims
);
out_dims
.
insert
(
out_dims
.
begin
()
+
1
,
num_directions
);
out_dims
.
back
()
=
hidden_size
;
return
{
inputs
[
0
].
type
(),
out_dims
};
}
};
struct
undefined
{
std
::
string
name
()
const
{
return
"undefined"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
return
{};
}
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
return
{{},
nullptr
};
}
};
}
// namespace op
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/par_dfor.hpp
0 → 100644
View file @
a5c1c7f6
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp>
#include <array>
#include <numeric>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
...
Ts
>
auto
par_dfor
(
Ts
...
xs
)
{
return
[
=
](
auto
f
)
{
using
array_type
=
std
::
array
<
std
::
size_t
,
sizeof
...(
Ts
)
>
;
array_type
lens
=
{{
static_cast
<
std
::
size_t
>
(
xs
)...}};
auto
n
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
{});
const
std
::
size_t
min_grain
=
8
;
if
(
n
>
2
*
min_grain
)
{
array_type
strides
;
strides
.
fill
(
1
);
std
::
partial_sum
(
lens
.
rbegin
(),
lens
.
rend
()
-
1
,
strides
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
auto
size
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
par_for
(
size
,
min_grain
,
[
&
](
std
::
size_t
i
)
{
array_type
indices
;
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
lens
.
begin
(),
indices
.
begin
(),
[
&
](
size_t
stride
,
size_t
len
)
{
return
(
i
/
stride
)
%
len
;
});
migraphx
::
unpack
(
f
,
indices
);
});
}
else
{
dfor
(
xs
...)(
f
);
}
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
Prev
1
2
3
4
5
6
7
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment