Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a5c1c7f6
"test/torchaudio_unittest/compliance_kaldi_test.py" did not exist on "08a71271b2c778fb79cc47b8213203432bf1dc1e"
Unverified
Commit
a5c1c7f6
authored
Feb 10, 2019
by
Paul Fultz II
Committed by
GitHub
Feb 10, 2019
Browse files
Merge branch 'develop' into mem_color_ordering_fix
parents
462a4920
d516b099
Changes
303
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
284 additions
and
108 deletions
+284
-108
src/include/migraphx/par_for.hpp
src/include/migraphx/par_for.hpp
+82
-0
src/include/migraphx/pass.hpp
src/include/migraphx/pass.hpp
+12
-6
src/include/migraphx/pass_config.hpp
src/include/migraphx/pass_config.hpp
+6
-6
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+10
-4
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+4
-4
src/include/migraphx/rank.hpp
src/include/migraphx/rank.hpp
+4
-4
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+10
-10
src/include/migraphx/reflect.hpp
src/include/migraphx/reflect.hpp
+4
-4
src/include/migraphx/requires.hpp
src/include/migraphx/requires.hpp
+14
-14
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+53
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+38
-15
src/include/migraphx/shape_for_each.hpp
src/include/migraphx/shape_for_each.hpp
+4
-4
src/include/migraphx/simplify_algebra.hpp
src/include/migraphx/simplify_algebra.hpp
+4
-4
src/include/migraphx/simplify_reshapes.hpp
src/include/migraphx/simplify_reshapes.hpp
+4
-4
src/include/migraphx/streamutils.hpp
src/include/migraphx/streamutils.hpp
+4
-4
src/include/migraphx/stringutils.hpp
src/include/migraphx/stringutils.hpp
+4
-4
src/include/migraphx/target.hpp
src/include/migraphx/target.hpp
+10
-4
src/include/migraphx/tensor_view.hpp
src/include/migraphx/tensor_view.hpp
+9
-9
src/include/migraphx/time.hpp
src/include/migraphx/time.hpp
+4
-4
src/include/migraphx/tracer.hpp
src/include/migraphx/tracer.hpp
+4
-4
No files found.
src/include/migraphx/par_for.hpp
0 → 100644
View file @
a5c1c7f6
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
joinable_thread
:
std
::
thread
{
template
<
class
...
Xs
>
joinable_thread
(
Xs
&&
...
xs
)
:
std
::
thread
(
std
::
forward
<
Xs
>
(
xs
)...)
// NOLINT
{
}
joinable_thread
&
operator
=
(
joinable_thread
&&
other
)
=
default
;
joinable_thread
(
joinable_thread
&&
other
)
=
default
;
~
joinable_thread
()
{
if
(
this
->
joinable
())
this
->
join
();
}
};
template
<
class
F
>
void
par_for_impl
(
std
::
size_t
n
,
std
::
size_t
threadsize
,
F
f
)
{
if
(
threadsize
<=
1
)
{
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
f
(
i
);
}
else
{
std
::
vector
<
joinable_thread
>
threads
(
threadsize
);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std
::
size_t
grainsize
=
std
::
ceil
(
static_cast
<
double
>
(
n
)
/
threads
.
size
());
std
::
size_t
work
=
0
;
std
::
generate
(
threads
.
begin
(),
threads
.
end
(),
[
=
,
&
work
]
{
auto
result
=
joinable_thread
([
=
]
{
std
::
size_t
start
=
work
;
std
::
size_t
last
=
std
::
min
(
n
,
work
+
grainsize
);
for
(
std
::
size_t
i
=
start
;
i
<
last
;
i
++
)
{
f
(
i
);
}
});
work
+=
grainsize
;
return
result
;
});
assert
(
work
>=
n
);
}
}
template
<
class
F
>
void
par_for
(
std
::
size_t
n
,
std
::
size_t
min_grain
,
F
f
)
{
const
auto
threadsize
=
std
::
min
<
std
::
size_t
>
(
std
::
thread
::
hardware_concurrency
(),
n
/
min_grain
);
par_for_impl
(
n
,
threadsize
,
f
);
}
template
<
class
F
>
void
par_for
(
std
::
size_t
n
,
F
f
)
{
const
int
min_grain
=
8
;
par_for
(
n
,
min_grain
,
f
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/pass.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#ifndef MIGRAPH
X
_GUARD_PASS_HPP
#define MIGRAPH
X
_GUARD_PASS_HPP
#include <cassert>
#include <string>
...
...
@@ -10,7 +10,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
program
;
...
...
@@ -105,7 +105,13 @@ struct pass
void
apply
(
program
&
p
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
}
friend
bool
is_shared
(
const
pass
&
private_detail_x
,
const
pass
&
private_detail_y
)
{
return
private_detail_x
.
private_detail_te_handle_mem_var
==
private_detail_y
.
private_detail_te_handle_mem_var
;
}
private:
...
...
@@ -149,7 +155,7 @@ struct pass
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
void
apply
(
program
&
p
)
const
override
{
return
private_detail_te_value
.
apply
(
p
);
}
void
apply
(
program
&
p
)
const
override
{
private_detail_te_value
.
apply
(
p
);
}
PrivateDetailTypeErasedT
private_detail_te_value
;
};
...
...
@@ -218,7 +224,7 @@ inline const ValueType& any_cast(const pass& x)
#endif
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/pass_config.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_PASS_CONFIG_HPP
#define MIGRAPH_GUARD_PASS_CONFIG_HPP
#ifndef MIGRAPH
X
_GUARD_PASS_CONFIG_HPP
#define MIGRAPH
X
_GUARD_PASS_CONFIG_HPP
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
MIGRAPH_DECLARE_ENV_VAR
(
MIGRAPH_DISABLE_MEMORY_COLORING
)
MIGRAPH
X
_DECLARE_ENV_VAR
(
MIGRAPH
X
_DISABLE_MEMORY_COLORING
)
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPH_GUARD_PASS_CONFIG_HPP
#endif // MIGRAPH
X
_GUARD_PASS_CONFIG_HPP
src/include/migraphx/program.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_PROGRAM_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_PROGRAM_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_PROGRAM_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_PROGRAM_HPP
#include <list>
#include <unordered_map>
...
...
@@ -14,7 +14,7 @@
#include <iostream>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
program_impl
;
...
...
@@ -91,16 +91,22 @@ struct program
shape
get_shape
()
const
;
context
&
get_context
()
const
;
instruction_ref
validate
()
const
;
void
compile
(
const
target
&
t
,
tracer
trace
=
tracer
{});
void
finalize
();
void
perf_report
(
std
::
ostream
&
os
,
std
::
size_t
n
,
parameter_map
params
)
const
;
void
debug_print
()
const
;
void
debug_print
(
instruction_ref
ins
)
const
;
void
debug_print
(
const
std
::
vector
<
instruction_ref
>&
inss
)
const
;
void
dry_run
(
parameter_map
params
)
const
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
...
...
@@ -109,7 +115,7 @@ struct program
std
::
unique_ptr
<
program_impl
>
impl
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/ranges.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_RANGES_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_RANGES_HPP
#include <algorithm>
#include <initializer_list>
...
...
@@ -7,7 +7,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
namespace
detail
{
...
...
@@ -106,7 +106,7 @@ iterator_range<Iterator> range(std::pair<Iterator, Iterator> p)
return
{
p
.
first
,
p
.
second
};
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/rank.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_RANK_HPP
#define MIGRAPH_GUARD_RTGLIB_RANK_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_RANK_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_RANK_HPP
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
int
N
>
struct
rank
:
rank
<
N
-
1
>
...
...
@@ -16,7 +16,7 @@ struct rank<0>
{
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/raw_data.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RAW_DATA_HPP
#define MIGRAPH_GUARD_RAW_DATA_HPP
#ifndef MIGRAPH
X
_GUARD_RAW_DATA_HPP
#define MIGRAPH
X
_GUARD_RAW_DATA_HPP
#include <migraphx/tensor_view.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
raw_data_base
{
...
...
@@ -126,7 +126,7 @@ struct raw_data : raw_data_base
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
if
(
s
.
type
()
!=
migraphx
::
shape
::
get_type
<
T
>
{})
MIGRAPH_THROW
(
"Incorrect data type for raw data"
);
MIGRAPH
X
_THROW
(
"Incorrect data type for raw data"
);
return
make_view
(
s
,
reinterpret_cast
<
T
*>
(
buffer
));
}
...
...
@@ -143,7 +143,7 @@ struct raw_data : raw_data_base
template
<
class
T
,
class
U
,
MIGRAPH_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
T
>{}
&&
MIGRAPH
X
_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
T
>{}
&&
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
==
(
const
T
&
x
,
const
U
&
y
)
{
...
...
@@ -166,7 +166,7 @@ bool operator==(const T& x, const U& y)
template
<
class
T
,
class
U
,
MIGRAPH_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
T
>{}
&&
MIGRAPH
X
_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
T
>{}
&&
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
{
...
...
@@ -198,14 +198,14 @@ auto visit_all(T&& x, Ts&&... xs)
auto
&&
s
=
x
.
get_shape
();
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPH_THROW
(
"Types must be the same"
);
MIGRAPH
X
_THROW
(
"Types must be the same"
);
return
[
&
](
auto
v
)
{
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
detail
::
visit_all_impl
(
s
,
v
,
x
,
xs
...);
};
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/reflect.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_REFLECT_HPP
#define MIGRAPH_GUARD_RTGLIB_REFLECT_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_REFLECT_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_REFLECT_HPP
#include <migraphx/functional.hpp>
#include <migraphx/rank.hpp>
...
...
@@ -7,7 +7,7 @@
#include <functional>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
namespace
detail
{
...
...
@@ -47,7 +47,7 @@ void reflect_each(T& x, F f)
});
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/requires.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_REQUIRES_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_REQUIRES_HPP
#include <type_traits>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
bool
...
Bs
>
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
||
true
)...
>>
// NOLINT
...
...
@@ -24,29 +24,29 @@ struct requires_enum
};
};
#define MIGRAPH_REQUIRES_CAT(x, y) x##y
#define MIGRAPH
X
_REQUIRES_CAT(x, y) x##y
#ifdef CPPCHECK
#define MIGRAPH_REQUIRES(...) class = void
#define MIGRAPH
X
_REQUIRES(...) class = void
#else
#if 0
// TODO: This currently crashed on clang
#define MIGRAPH_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPH_REQUIRES_CAT( \
#define MIGRAPH
X
_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPH
X
_REQUIRES_CAT( \
PrivateRequires, \
__LINE__) = migraphx::requires_enum<__LINE__>::a, \
class = typename std::enable_if<and_<__VA_ARGS__, \
MIGRAPH_REQUIRES_CAT(PrivateRequires, __LINE__) == \
MIGRAPH
X
_REQUIRES_CAT(PrivateRequires, __LINE__) == \
migraphx::requires_enum<__LINE__>::a>{}>::type
#else
#define MIGRAPH_REQUIRES(...)
\
typename migraphx::requires_enum<__LINE__>::e MIGRAPH_REQUIRES_CAT(
\
#define MIGRAPH
X
_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPH
X
_REQUIRES_CAT( \
PrivateRequires, __LINE__) = migraphx::requires_enum<__LINE__>::a, \
class = typename std::enable_if<and_<__VA_ARGS__>{}>::type
#endif
#endif
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/rewrite_rnn.hpp
0 → 100644
View file @
a5c1c7f6
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
program
;
/**
* Rewrite rnn to gemm and add.
*/
struct
rewrite_rnn
{
std
::
string
name
()
const
{
return
"rewrite_rnn"
;
}
void
apply
(
program
&
prog
)
const
;
private:
// for vanilla rnn operators
void
apply_vanilla_rnn
(
program
&
prog
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
vanilla_rnn_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
w
,
instruction_ref
r
,
instruction_ref
bias
,
instruction_ref
ih
,
operation
&
actv_func
)
const
;
std
::
vector
<
operation
>
vanilla_rnn_actv_funcs
(
instruction_ref
ins
)
const
;
// for gru operators
void
apply_gru
(
program
&
prog
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
gru_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
const
operation
&
actv_func1
,
const
operation
&
actv_func2
)
const
;
std
::
vector
<
operation
>
gru_actv_funcs
(
instruction_ref
ins
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/shape.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_SHAPE_HPP
#include <vector>
#include <cassert>
...
...
@@ -12,7 +12,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
shape_impl
;
...
...
@@ -21,7 +21,7 @@ struct shape
// Add new types here
// clang-format off
#define MIGRAPH_SHAPE_VISIT_TYPES(m) \
#define MIGRAPH
X
_SHAPE_VISIT_TYPES(m) \
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
...
...
@@ -35,22 +35,22 @@ struct shape
m(uint64_type, uint64_t)
// clang-format on
#define MIGRAPH_SHAPE_ENUM_TYPES(x, t) x,
#define MIGRAPH
X
_SHAPE_
GENERATE_
ENUM_TYPES(x, t) x,
enum
type_t
{
MIGRAPH_SHAPE_VISIT_TYPES
(
MIGRAPH_SHAPE_ENUM_TYPES
)
MIGRAPH
X
_SHAPE_VISIT_TYPES
(
MIGRAPH
X
_SHAPE_
GENERATE_
ENUM_TYPES
)
};
#undef MIGRAPH_SHAPE_ENUM_TYPES
#undef MIGRAPH
X
_SHAPE_
GENERATE_
ENUM_TYPES
template
<
class
T
,
class
=
void
>
struct
get_type
;
#define MIGRAPH_SHAPE_GET_TYPE(x, t)
\
#define MIGRAPH
X
_SHAPE_
GENERATE_
GET_TYPE(x, t) \
template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{ \
};
MIGRAPH_SHAPE_VISIT_TYPES
(
MIGRAPH_SHAPE_GET_TYPE
)
#undef MIGRAPH_SHAPE_GET_TYPE
MIGRAPH
X
_SHAPE_VISIT_TYPES
(
MIGRAPH
X
_SHAPE_
GENERATE_
GET_TYPE
)
#undef MIGRAPH
X
_SHAPE_
GENERATE_
GET_TYPE
template
<
class
T
>
struct
get_type
<
const
T
>
:
get_type
<
T
>
...
...
@@ -62,6 +62,19 @@ struct shape
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
template
<
class
Range
>
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
{
}
template
<
class
Range1
,
class
Range2
>
shape
(
type_t
t
,
const
Range1
&
l
,
const
Range2
&
s
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()),
std
::
vector
<
std
::
size_t
>
(
s
.
begin
(),
s
.
end
()))
{
}
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
...
...
@@ -141,6 +154,8 @@ struct shape
{
return
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
;
}
type_t
type_enum
()
const
{
return
get_type
<
T
>
{};
}
};
template
<
class
Visitor
>
...
...
@@ -148,12 +163,20 @@ struct shape
{
switch
(
this
->
type
())
{
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
#define MIGRAPH
X
_SHAPE_
GENERATE_
VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
MIGRAPH_SHAPE_VISIT_TYPES
(
MIGRAPH_SHAPE_VISITOR_CASE
)
#undef MIGRAPH_SHAPE_VISITOR_CASE
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE
)
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE
}
MIGRAPHX_THROW
(
"Unknown type"
);
}
MIGRAPH_THROW
(
"Unknown type"
);
template
<
class
Visitor
>
static
void
visit_types
(
Visitor
v
)
{
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) v(as<t>());
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
)
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
}
private:
...
...
@@ -163,7 +186,7 @@ struct shape
std
::
string
type_string
()
const
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/shape_for_each.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
F
>
void
shape_for_each
(
const
migraphx
::
shape
&
s
,
F
f
)
...
...
@@ -28,7 +28,7 @@ void shape_for_each(const migraphx::shape& s, F f)
}
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/simplify_algebra.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#include <string>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
program
;
...
...
@@ -18,7 +18,7 @@ struct simplify_algebra
void
apply
(
program
&
p
)
const
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/simplify_reshapes.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
program
;
...
...
@@ -19,7 +19,7 @@ struct simplify_reshapes
void
apply
(
program
&
p
)
const
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/streamutils.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_STREAMUTILS_HPP
#define MIGRAPH_GUARD_STREAMUTILS_HPP
#ifndef MIGRAPH
X
_GUARD_STREAMUTILS_HPP
#define MIGRAPH
X
_GUARD_STREAMUTILS_HPP
#include <ostream>
#include <algorithm>
...
...
@@ -7,7 +7,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
T
>
struct
stream_range_container
...
...
@@ -56,7 +56,7 @@ void stream_write_value(std::ostream& os, const T& x)
detail
::
stream_write_value_impl
(
rank
<
1
>
{},
os
,
x
);
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/stringutils.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#include <algorithm>
#include <numeric>
...
...
@@ -8,7 +8,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
inline
std
::
string
replace_string
(
std
::
string
subject
,
const
std
::
string
&
search
,
const
std
::
string
&
replace
)
...
...
@@ -87,7 +87,7 @@ inline std::string to_string(const T& x)
return
ss
.
str
();
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/target.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_TARGET_HPP
#include <cassert>
#include <string>
...
...
@@ -13,7 +13,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
#ifdef DOXYGEN
...
...
@@ -127,6 +127,12 @@ struct target
return
(
*
this
).
private_detail_te_get_handle
().
get_context
();
}
friend
bool
is_shared
(
const
target
&
private_detail_x
,
const
target
&
private_detail_y
)
{
return
private_detail_x
.
private_detail_te_handle_mem_var
==
private_detail_y
.
private_detail_te_handle_mem_var
;
}
private:
struct
private_detail_te_handle_base_type
{
...
...
@@ -244,7 +250,7 @@ inline const ValueType& any_cast(const target& x)
#endif
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/tensor_view.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_TENSOR_VIEW_HPP
#define MIGRAPH_GUARD_TENSOR_VIEW_HPP
#ifndef MIGRAPH
X
_GUARD_TENSOR_VIEW_HPP
#define MIGRAPH
X
_GUARD_TENSOR_VIEW_HPP
#include <migraphx/shape.hpp>
#include <migraphx/float_equal.hpp>
...
...
@@ -10,7 +10,7 @@
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
T
>
struct
tensor_view
...
...
@@ -29,7 +29,7 @@ struct tensor_view
const
T
*
data
()
const
{
return
this
->
m_data
;
}
template
<
class
...
Ts
,
MIGRAPH_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
template
<
class
...
Ts
,
MIGRAPH
X
_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
const
T
&
operator
()(
Ts
...
xs
)
const
{
assert
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
xs
)...}
<
m_shape
.
lens
());
...
...
@@ -37,7 +37,7 @@ struct tensor_view
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
}
template
<
class
...
Ts
,
MIGRAPH_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
template
<
class
...
Ts
,
MIGRAPH
X
_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
T
&
operator
()(
Ts
...
xs
)
{
assert
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
xs
)...}
<
m_shape
.
lens
());
...
...
@@ -45,13 +45,13 @@ struct tensor_view
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
}
template
<
class
Iterator
,
MIGRAPH_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
template
<
class
Iterator
,
MIGRAPH
X
_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
const
T
&
operator
()(
Iterator
start
,
Iterator
last
)
const
{
return
m_data
[
m_shape
.
index
(
start
,
last
)];
}
template
<
class
Iterator
,
MIGRAPH_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
template
<
class
Iterator
,
MIGRAPH
X
_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
T
&
operator
()(
Iterator
start
,
Iterator
last
)
{
return
m_data
[
m_shape
.
index
(
start
,
last
)];
...
...
@@ -164,12 +164,12 @@ bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y)
}
template
<
class
T
>
tensor_view
<
T
>
make_view
(
shape
s
,
T
*
data
)
tensor_view
<
T
>
make_view
(
const
shape
&
s
,
T
*
data
)
{
return
{
s
,
data
};
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/time.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_TIME_HPP
#define MIGRAPH_GUARD_RTGLIB_TIME_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_TIME_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_TIME_HPP
#include <chrono>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
Duration
,
class
F
>
auto
time
(
F
f
)
...
...
@@ -16,7 +16,7 @@ auto time(F f)
return
std
::
chrono
::
duration_cast
<
Duration
>
(
finish
-
start
).
count
();
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/tracer.hpp
View file @
a5c1c7f6
#ifndef MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_TRACER_HPP
#include <ostream>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
tracer
{
...
...
@@ -30,7 +30,7 @@ struct tracer
std
::
ostream
*
os
=
nullptr
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
Prev
1
2
3
4
5
6
7
8
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment