Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
eb0d8fee
Commit
eb0d8fee
authored
Jun 04, 2019
by
Paul
Browse files
Merge branch 'develop' into driver
parents
65ef35cd
0d796941
Changes
320
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1185 additions
and
251 deletions
+1185
-251
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+12
-5
src/include/migraphx/reflect.hpp
src/include/migraphx/reflect.hpp
+61
-6
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+13
-1
src/include/migraphx/schedule.hpp
src/include/migraphx/schedule.hpp
+28
-0
src/include/migraphx/schedule_model.hpp
src/include/migraphx/schedule_model.hpp
+286
-0
src/include/migraphx/streamutils.hpp
src/include/migraphx/streamutils.hpp
+3
-1
src/include/migraphx/stringutils.hpp
src/include/migraphx/stringutils.hpp
+3
-2
src/include/migraphx/target.hpp
src/include/migraphx/target.hpp
+0
-2
src/include/migraphx/tensor_view.hpp
src/include/migraphx/tensor_view.hpp
+12
-2
src/include/migraphx/tf.hpp
src/include/migraphx/tf.hpp
+16
-0
src/instruction.cpp
src/instruction.cpp
+30
-8
src/onnx/cifar10.cpp
src/onnx/cifar10.cpp
+1
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+410
-94
src/opt/common_header.hpp
src/opt/common_header.hpp
+0
-30
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+10
-19
src/opt/memory_coloring_impl.hpp
src/opt/memory_coloring_impl.hpp
+33
-18
src/pass_manager.cpp
src/pass_manager.cpp
+42
-0
src/program.cpp
src/program.cpp
+149
-46
src/propagate_constant.cpp
src/propagate_constant.cpp
+54
-0
src/py/CMakeLists.txt
src/py/CMakeLists.txt
+22
-16
No files found.
src/include/migraphx/raw_data.hpp
View file @
eb0d8fee
...
...
@@ -27,7 +27,8 @@ struct raw_data : raw_data_base
template
<
class
Stream
>
friend
Stream
&
operator
<<
(
Stream
&
os
,
const
Derived
&
d
)
{
d
.
visit
([
&
](
auto
x
)
{
os
<<
x
;
});
if
(
not
d
.
empty
())
d
.
visit
([
&
](
auto
x
)
{
os
<<
x
;
});
return
os
;
}
...
...
@@ -40,8 +41,11 @@ struct raw_data : raw_data_base
template
<
class
Visitor
>
void
visit_at
(
Visitor
v
,
std
::
size_t
n
=
0
)
const
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
if
(
derived
.
empty
())
MIGRAPHX_THROW
(
"Visiting empty data!"
);
auto
&&
s
=
derived
.
get_shape
();
auto
&&
buffer
=
derived
.
data
();
s
.
visit_type
([
&
](
auto
as
)
{
v
(
*
(
as
.
from
(
buffer
)
+
s
.
index
(
n
)));
});
}
...
...
@@ -55,8 +59,11 @@ struct raw_data : raw_data_base
template
<
class
Visitor
>
void
visit
(
Visitor
v
)
const
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
if
(
derived
.
empty
())
MIGRAPHX_THROW
(
"Visiting empty data!"
);
auto
&&
s
=
derived
.
get_shape
();
auto
&&
buffer
=
derived
.
data
();
s
.
visit_type
([
&
](
auto
as
)
{
v
(
make_view
(
s
,
as
.
from
(
buffer
)));
});
}
...
...
src/include/migraphx/reflect.hpp
View file @
eb0d8fee
...
...
@@ -11,6 +11,15 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
detail
{
struct
reflect_placeholder
{
template
<
class
...
Ts
>
int
operator
()(
Ts
&&
...)
const
{
return
0
;
}
};
template
<
class
T
,
class
Selector
>
auto
reflect_impl
(
rank
<
1
>
,
T
&
x
,
Selector
f
)
->
decltype
(
T
::
reflect
(
x
,
f
))
{
...
...
@@ -23,8 +32,53 @@ auto reflect_impl(rank<0>, T&, Selector)
return
pack
();
}
template
<
class
T
>
auto
reflectable_impl
(
rank
<
1
>
,
T
&&
x
)
->
decltype
(
T
::
reflect
(
x
,
reflect_placeholder
{}),
std
::
true_type
{});
template
<
class
T
>
auto
reflectable_impl
(
rank
<
0
>
,
T
&&
)
->
decltype
(
std
::
false_type
{});
template
<
class
T
>
struct
remove_rvalue_reference
{
using
type
=
T
;
};
template
<
class
T
>
struct
remove_rvalue_reference
<
T
&&>
{
using
type
=
T
;
};
template
<
class
T
>
struct
wrapper
{
using
type
=
typename
remove_rvalue_reference
<
T
>::
type
;
type
data
;
type
get
()
const
{
return
data
;
}
};
template
<
class
T
>
wrapper
<
T
>
wrap
(
std
::
remove_reference_t
<
T
>&
x
)
{
return
wrapper
<
T
>
{
std
::
forward
<
T
>
(
x
)};
}
template
<
class
...
Ts
>
using
auto_tuple_t
=
std
::
tuple
<
typename
remove_rvalue_reference
<
Ts
>::
type
...
>
;
template
<
class
...
Ts
>
auto_tuple_t
<
Ts
...
>
auto_tuple
(
Ts
&&
...
xs
)
{
return
auto_tuple_t
<
Ts
...
>
{
std
::
forward
<
Ts
>
(
xs
)...};
}
}
// namespace detail
template
<
class
T
>
using
is_reflectable
=
decltype
(
detail
::
reflectable_impl
(
rank
<
1
>
{},
std
::
declval
<
T
>
()));
template
<
class
T
,
class
Selector
>
auto
reflect
(
T
&
x
,
Selector
f
)
{
...
...
@@ -34,17 +88,18 @@ auto reflect(T& x, Selector f)
template
<
class
T
>
auto
reflect_tie
(
T
&
x
)
{
return
reflect
(
x
,
[](
auto
&&
y
,
auto
&&
...)
{
return
std
::
ref
(
y
);
})(
[](
auto
&&
...
xs
)
{
return
std
::
ti
e
(
xs
.
get
()...);
});
return
reflect
(
x
,
[](
auto
&&
y
,
auto
&&
...)
{
return
detail
::
wrap
<
decltype
(
y
)
>
(
y
);
})(
[](
auto
&&
...
xs
)
{
return
detail
::
auto_tupl
e
(
xs
.
get
()...);
});
}
template
<
class
T
,
class
F
>
void
reflect_each
(
T
&
x
,
F
f
)
{
return
reflect
(
x
,
[](
auto
&&
y
,
auto
...
ys
)
{
return
pack
(
std
::
ref
(
y
),
ys
...);
})(
[
&
](
auto
&&
...
xs
)
{
each_args
([
&
](
auto
p
)
{
p
([
&
](
auto
&&
y
,
auto
...
ys
)
{
f
(
y
.
get
(),
ys
...);
});
},
xs
...);
});
return
reflect
(
x
,
[](
auto
&&
y
,
auto
...
ys
)
{
return
pack
(
detail
::
wrap
<
decltype
(
y
)
>
(
y
),
ys
...);
})([
&
](
auto
&&
...
xs
)
{
each_args
([
&
](
auto
p
)
{
p
([
&
](
auto
&&
y
,
auto
...
ys
)
{
f
(
y
.
get
(),
ys
...);
});
},
xs
...);
});
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_rnn.hpp
View file @
eb0d8fee
...
...
@@ -4,7 +4,7 @@
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operat
ors
.hpp>
#include <migraphx/operat
ion
.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
...
...
@@ -45,6 +45,18 @@ struct rewrite_rnn
const
operation
&
actv_func2
)
const
;
std
::
vector
<
operation
>
gru_actv_funcs
(
instruction_ref
ins
)
const
;
// for lstm operators
void
apply_lstm
(
program
&
prog
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
lstm_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
const
operation
&
actv_func1
,
const
operation
&
actv_func2
,
const
operation
&
actv_func3
)
const
;
std
::
vector
<
operation
>
lstm_actv_funcs
(
instruction_ref
ins
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/schedule.hpp
0 → 100644
View file @
eb0d8fee
#ifndef MIGRAPHX_GUARD_RTGLIB_SCHEDULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_SCHEDULE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/schedule_model.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
program
;
/**
* Schedule instructions for concurrent execution
*/
struct
schedule
{
schedule_model
model
{};
bool
enable
=
true
;
std
::
string
name
()
const
{
return
"schedule"
;
}
void
apply
(
program
&
p
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/schedule_model.hpp
0 → 100644
View file @
eb0d8fee
#ifndef MIGRAPHX_GUARD_SCHEDULE_MODEL_HPP
#define MIGRAPHX_GUARD_SCHEDULE_MODEL_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
program
;
struct
operation
;
#ifdef DOXYGEN
/// An interface for target-dependent model for the scheduler
struct
schedule_model
{
/// Get the number of concurrent instruction allowed
std
::
size_t
concurrency
()
const
;
/// Schedule a concurrent instruction
void
sched
(
program
&
p
,
instruction_ref
ins
,
std
::
size_t
n
)
const
;
// Insert necessary waits before an instruction
void
wait
(
program
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
// Insert necessary records after an instruction
void
record
(
program
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
/// Compute weights for an operation
std
::
size_t
weight
(
const
operation
&
op
)
const
;
};
#else
/*
* Type-erased interface for:
*
* struct schedule_model
* {
* std::size_t concurrency() const;
* void sched(program& p,instruction_ref ins,std::size_t n) const;
* void wait(program& p,instruction_ref ins,std::size_t wait_id) const;
* void record(program& p,instruction_ref ins,std::size_t wait_id) const;
* std::size_t weight(const operation& op) const;
* };
*
*/
struct
schedule_model
{
// Constructors
schedule_model
()
=
default
;
template
<
typename
PrivateDetailTypeErasedT
>
schedule_model
(
PrivateDetailTypeErasedT
value
)
:
private_detail_te_handle_mem_var
(
std
::
make_shared
<
private_detail_te_handle_type
<
typename
std
::
remove_reference
<
PrivateDetailTypeErasedT
>::
type
>>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
)))
{
}
// Assignment
template
<
typename
PrivateDetailTypeErasedT
>
schedule_model
&
operator
=
(
PrivateDetailTypeErasedT
value
)
{
if
(
private_detail_te_handle_mem_var
.
unique
())
*
private_detail_te_handle_mem_var
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
else
if
(
!
private_detail_te_handle_mem_var
)
private_detail_te_handle_mem_var
=
std
::
make_shared
<
PrivateDetailTypeErasedT
>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
));
return
*
this
;
}
// Cast
template
<
typename
PrivateDetailTypeErasedT
>
PrivateDetailTypeErasedT
*
any_cast
()
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
.
private_detail_te_value
)
:
nullptr
;
}
template
<
typename
PrivateDetailTypeErasedT
>
const
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
*
any_cast
()
const
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
const
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
.
private_detail_te_value
)
:
nullptr
;
}
const
std
::
type_info
&
type_id
()
const
{
if
(
private_detail_te_handle_empty
())
return
typeid
(
std
::
nullptr_t
);
else
return
private_detail_te_get_handle
().
type
();
}
std
::
size_t
concurrency
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
concurrency
();
}
void
sched
(
program
&
p
,
instruction_ref
ins
,
std
::
size_t
n
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
sched
(
p
,
ins
,
n
);
}
void
wait
(
program
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
wait
(
p
,
ins
,
wait_id
);
}
void
record
(
program
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
record
(
p
,
ins
,
wait_id
);
}
std
::
size_t
weight
(
const
operation
&
op
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
weight
(
op
);
}
friend
bool
is_shared
(
const
schedule_model
&
private_detail_x
,
const
schedule_model
&
private_detail_y
)
{
return
private_detail_x
.
private_detail_te_handle_mem_var
==
private_detail_y
.
private_detail_te_handle_mem_var
;
}
private:
struct
private_detail_te_handle_base_type
{
virtual
~
private_detail_te_handle_base_type
()
{}
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
size_t
concurrency
()
const
=
0
;
virtual
void
sched
(
program
&
p
,
instruction_ref
ins
,
std
::
size_t
n
)
const
=
0
;
virtual
void
wait
(
program
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
=
0
;
virtual
void
record
(
program
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
=
0
;
virtual
std
::
size_t
weight
(
const
operation
&
op
)
const
=
0
;
};
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
{
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
>::
type
*
=
nullptr
)
:
private_detail_te_value
(
value
)
{
}
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
}
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
override
{
return
std
::
make_shared
<
private_detail_te_handle_type
>
(
private_detail_te_value
);
}
const
std
::
type_info
&
type
()
const
override
{
return
typeid
(
private_detail_te_value
);
}
std
::
size_t
concurrency
()
const
override
{
return
private_detail_te_value
.
concurrency
();
}
void
sched
(
program
&
p
,
instruction_ref
ins
,
std
::
size_t
n
)
const
override
{
private_detail_te_value
.
sched
(
p
,
ins
,
n
);
}
void
wait
(
program
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
override
{
private_detail_te_value
.
wait
(
p
,
ins
,
wait_id
);
}
void
record
(
program
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
override
{
private_detail_te_value
.
record
(
p
,
ins
,
wait_id
);
}
std
::
size_t
weight
(
const
operation
&
op
)
const
override
{
return
private_detail_te_value
.
weight
(
op
);
}
PrivateDetailTypeErasedT
private_detail_te_value
;
};
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
<
std
::
reference_wrapper
<
PrivateDetailTypeErasedT
>>
:
private_detail_te_handle_type
<
PrivateDetailTypeErasedT
&>
{
private_detail_te_handle_type
(
std
::
reference_wrapper
<
PrivateDetailTypeErasedT
>
ref
)
:
private_detail_te_handle_type
<
PrivateDetailTypeErasedT
&>
(
ref
.
get
())
{
}
};
bool
private_detail_te_handle_empty
()
const
{
return
private_detail_te_handle_mem_var
==
nullptr
;
}
const
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
const
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
return
*
private_detail_te_handle_mem_var
;
}
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
private_detail_te_handle_mem_var
;
};
template
<
typename
ValueType
>
inline
const
ValueType
*
any_cast
(
const
schedule_model
*
x
)
{
return
x
->
any_cast
<
ValueType
>
();
}
template
<
typename
ValueType
>
inline
ValueType
*
any_cast
(
schedule_model
*
x
)
{
return
x
->
any_cast
<
ValueType
>
();
}
template
<
typename
ValueType
>
inline
ValueType
&
any_cast
(
schedule_model
&
x
)
{
auto
*
y
=
x
.
any_cast
<
typename
std
::
remove_reference
<
ValueType
>::
type
>
();
if
(
y
==
nullptr
)
throw
std
::
bad_cast
();
return
*
y
;
}
template
<
typename
ValueType
>
inline
const
ValueType
&
any_cast
(
const
schedule_model
&
x
)
{
const
auto
*
y
=
x
.
any_cast
<
typename
std
::
remove_reference
<
ValueType
>::
type
>
();
if
(
y
==
nullptr
)
throw
std
::
bad_cast
();
return
*
y
;
}
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/streamutils.hpp
View file @
eb0d8fee
...
...
@@ -36,6 +36,8 @@ inline stream_range_container<Range> stream_range(const Range& r)
namespace
detail
{
inline
void
stream_write_value_impl
(
rank
<
2
>
,
std
::
ostream
&
os
,
const
std
::
string
&
x
)
{
os
<<
x
;
}
template
<
class
Range
>
auto
stream_write_value_impl
(
rank
<
1
>
,
std
::
ostream
&
os
,
const
Range
&
r
)
->
decltype
(
r
.
begin
(),
r
.
end
(),
void
())
...
...
@@ -53,7 +55,7 @@ void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
template
<
class
T
>
void
stream_write_value
(
std
::
ostream
&
os
,
const
T
&
x
)
{
detail
::
stream_write_value_impl
(
rank
<
1
>
{},
os
,
x
);
detail
::
stream_write_value_impl
(
rank
<
2
>
{},
os
,
x
);
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/stringutils.hpp
View file @
eb0d8fee
...
...
@@ -38,8 +38,9 @@ inline std::string join_strings(Strings strings, const std::string& delim)
return
""
;
auto
nit
=
std
::
next
(
it
);
return
std
::
accumulate
(
nit
,
strings
.
end
(),
*
it
,
[
&
](
std
::
string
x
,
std
::
string
y
)
{
return
x
+
delim
+
y
;
});
return
std
::
accumulate
(
nit
,
strings
.
end
(),
*
it
,
[
&
](
std
::
string
x
,
std
::
string
y
)
{
return
std
::
move
(
x
)
+
delim
+
std
::
move
(
y
);
});
}
template
<
class
F
>
...
...
src/include/migraphx/target.hpp
View file @
eb0d8fee
...
...
@@ -22,10 +22,8 @@ struct target
{
/// A unique name used to identify the target
std
::
string
name
()
const
;
/// The transformation passes to be run
/**
* @brief The transformation pass to be run during compilation.
* @details [long description]
*
* @param ctx This is the target-dependent context that is created by `get_context`
* @return The passes to be ran
...
...
src/include/migraphx/tensor_view.hpp
View file @
eb0d8fee
...
...
@@ -12,6 +12,14 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
T
>
T
as_number
(
T
x
)
{
return
x
;
}
inline
int32_t
as_number
(
int8_t
x
)
{
return
static_cast
<
int32_t
>
(
x
);
}
inline
uint32_t
as_number
(
uint8_t
x
)
{
return
static_cast
<
uint32_t
>
(
x
);
}
template
<
class
T
>
struct
tensor_view
{
...
...
@@ -124,14 +132,16 @@ struct tensor_view
return
m_data
+
this
->
size
();
}
std
::
vector
<
T
>
to_vector
()
const
{
return
std
::
vector
<
T
>
(
this
->
begin
(),
this
->
end
());
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
{
if
(
!
x
.
empty
())
{
os
<<
x
.
front
();
os
<<
as_number
(
x
.
front
()
)
;
for
(
std
::
size_t
i
=
1
;
i
<
x
.
m_shape
.
elements
();
i
++
)
{
os
<<
", "
<<
x
.
m_data
[
x
.
m_shape
.
index
(
i
)];
os
<<
", "
<<
as_number
(
x
.
m_data
[
x
.
m_shape
.
index
(
i
)]
)
;
}
}
return
os
;
...
...
src/include/migraphx/tf.hpp
0 → 100644
View file @
eb0d8fee
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_TF_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_TF_HPP
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// Create a program from a tf pb file (default is nhwc format)
program
parse_tf
(
const
std
::
string
&
name
,
bool
is_nhwc
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/instruction.cpp
View file @
eb0d8fee
...
...
@@ -28,6 +28,12 @@ void instruction::replace(const shape& r)
}
}
void
instruction
::
replace
(
operation
o
)
{
op
=
std
::
move
(
o
);
recompute_shape
();
}
void
instruction
::
recompute_shape
()
{
replace
(
compute_shape
(
op
,
arguments
));
}
void
instruction
::
clear_arguments
()
...
...
@@ -162,7 +168,24 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old
->
remove_output
(
*
this
);
}
argument
instruction
::
eval
()
const
bool
instruction
::
can_eval
()
const
{
if
(
op
.
name
()
==
"@literal"
)
{
return
true
;
}
else
if
(
is_context_free
(
op
))
{
return
std
::
all_of
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
can_eval
();
});
}
else
{
return
false
;
}
}
argument
instruction
::
eval
(
bool
check_eval
)
const
{
if
(
op
.
name
()
==
"@literal"
)
{
...
...
@@ -170,14 +193,13 @@ argument instruction::eval() const
}
if
(
is_context_free
(
op
))
{
if
(
check_eval
and
not
this
->
can_eval
())
return
{};
std
::
vector
<
argument
>
args
;
for
(
auto
&&
arg
:
this
->
inputs
())
{
argument
a
=
arg
->
eval
();
if
(
a
.
empty
())
return
{};
args
.
push_back
(
a
);
}
std
::
transform
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
std
::
back_inserter
(
args
),
[](
auto
arg
)
{
return
arg
->
eval
(
false
);
});
return
op
.
compute
(
result
,
args
);
}
return
{};
...
...
src/onnx/cifar10.cpp
View file @
eb0d8fee
...
...
@@ -32,7 +32,7 @@ auto read_cifar10_images(const std::string& full_path)
labels
[
i
]
=
*
pimage
++
;
for
(
size_t
j
=
0
;
j
<
nbytes_per_image
;
j
++
)
{
float
v
=
*
(
pimage
+
j
)
/
255.0
f
;
float
v
=
float
(
*
(
pimage
+
j
)
)
/
255.0
f
;
data
[
i
*
nbytes_per_image
+
j
]
=
v
;
}
}
...
...
src/onnx/onnx.cpp
View file @
eb0d8fee
...
...
@@ -36,7 +36,6 @@ struct onnx_parser
onnx_parser
()
{
add_generic_op
(
"MatMul"
,
op
::
dot
{});
add_generic_op
(
"Relu"
,
op
::
relu
{});
add_generic_op
(
"Sigmoid"
,
op
::
sigmoid
{});
add_generic_op
(
"Abs"
,
op
::
abs
{});
...
...
@@ -64,6 +63,7 @@ struct onnx_parser
add_variadic_op
(
"Max"
,
op
::
max
{});
add_variadic_op
(
"Min"
,
op
::
min
{});
add_mem_op
(
"Clip"
,
&
onnx_parser
::
parse_clip
);
add_mem_op
(
"LRN"
,
&
onnx_parser
::
parse_lrn
);
add_mem_op
(
"ImageScaler"
,
&
onnx_parser
::
parse_imagescaler
);
add_mem_op
(
"LeakyRelu"
,
&
onnx_parser
::
parse_leaky_relu
);
...
...
@@ -77,8 +77,10 @@ struct onnx_parser
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"MatMul"
,
&
onnx_parser
::
parse_matmul
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
add_mem_op
(
"Softmax"
,
&
onnx_parser
::
parse_softmax
);
add_mem_op
(
"LogSoftmax"
,
&
onnx_parser
::
parse_logsoftmax
);
add_mem_op
(
"Squeeze"
,
&
onnx_parser
::
parse_squeeze
);
add_mem_op
(
"Unsqueeze"
,
&
onnx_parser
::
parse_unsqueeze
);
add_mem_op
(
"Slice"
,
&
onnx_parser
::
parse_slice
);
...
...
@@ -89,6 +91,7 @@ struct onnx_parser
add_mem_op
(
"Transpose"
,
&
onnx_parser
::
parse_transpose
);
add_mem_op
(
"RNN"
,
&
onnx_parser
::
parse_rnn
);
add_mem_op
(
"GRU"
,
&
onnx_parser
::
parse_gru
);
add_mem_op
(
"LSTM"
,
&
onnx_parser
::
parse_lstm
);
add_mem_op
(
"Pad"
,
&
onnx_parser
::
parse_pad
);
// init the activation function map
...
...
@@ -139,8 +142,8 @@ struct onnx_parser
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
();
auto
l
=
prog
.
add_instruction
(
op
::
broadcast
{
axis
,
args
[
0
]
->
get_shape
()},
args
[
1
]);
auto
l
=
prog
.
add_instruction
(
op
::
broadcast
{
axis
,
args
[
0
]
->
get_shape
().
lens
()},
args
[
1
]);
return
prog
.
add_instruction
(
x
,
args
[
0
],
l
);
}
return
prog
.
add_instruction
(
x
,
args
);
...
...
@@ -152,42 +155,48 @@ struct onnx_parser
});
}
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
if
(
s0
.
size
()
>
s1
.
size
())
{
s0
.
swap
(
s1
);
}
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
return
out_lens
;
}
template
<
class
T
>
instruction_ref
add_broadcastable_binary_op
(
instruction_ref
arg0
,
instruction_ref
arg1
,
T
x
)
{
if
(
arg0
->
get_shape
()
!=
arg1
->
get_shape
())
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
if
(
arg0
->
get_shape
().
lens
()
!=
arg1
->
get_shape
().
lens
())
{
// Get lengths for both arguments
const
std
::
vector
<
std
::
size_t
>*
s0
=
&
arg0
->
get_shape
().
lens
();
const
std
::
vector
<
std
::
size_t
>*
s1
=
&
arg1
->
get_shape
().
lens
();
// Make sure s0 is the smaller size
if
(
s0
->
size
()
>
s1
->
size
())
std
::
swap
(
s0
,
s1
);
std
::
vector
<
std
::
size_t
>
output_lens
(
*
s1
);
auto
offset
=
s1
->
size
()
-
s0
->
size
();
std
::
transform
(
s0
->
begin
(),
s0
->
end
(),
s1
->
begin
()
+
offset
,
output_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg0
);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg1
);
auto
s0
=
arg0
->
get_shape
().
lens
();
auto
s1
=
arg1
->
get_shape
().
lens
();
auto
out_lens
=
compute_broadcasted_lens
(
s0
,
s1
);
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
arg0
);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
arg1
);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
}
else
...
...
@@ -199,7 +208,7 @@ struct onnx_parser
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
x
,
args
);
});
}
...
...
@@ -207,7 +216,7 @@ struct onnx_parser
template
<
class
T
>
void
add_variadic_op
(
std
::
string
name
,
T
x
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
return
std
::
accumulate
(
std
::
next
(
args
.
begin
()),
args
.
end
(),
args
.
front
(),
...
...
@@ -217,6 +226,22 @@ struct onnx_parser
});
}
instruction_ref
parse_clip
(
const
std
::
string
&
,
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
op
::
clip
op
;
if
(
contains
(
attributes
,
"max"
))
{
op
.
max_val
=
parse_value
(
attributes
.
at
(
"max"
)).
at
<
float
>
();
}
if
(
contains
(
attributes
,
"min"
))
{
op
.
min_val
=
parse_value
(
attributes
.
at
(
"min"
)).
at
<
float
>
();
}
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
}
instruction_ref
parse_softmax
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
@@ -227,6 +252,19 @@ struct onnx_parser
return
prog
.
add_instruction
(
op
::
reshape
{{
long
(
dims
[
0
]),
long
(
dims
[
1
])}},
s
);
}
instruction_ref
parse_logsoftmax
(
const
std
::
string
&
,
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
int
axis
=
1
;
if
(
contains
(
attributes
,
"axis"
))
{
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
}
return
prog
.
add_instruction
(
op
::
logsoftmax
{
axis
},
std
::
move
(
args
));
}
instruction_ref
parse_conv
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
@@ -285,7 +323,7 @@ struct onnx_parser
{
uint64_t
axis
=
1
;
auto
l1
=
prog
.
add_instruction
(
op
,
args
[
0
],
args
[
1
]);
auto
l2
=
prog
.
add_instruction
(
op
::
broadcast
{
axis
,
l1
->
get_shape
()},
args
[
2
]);
auto
l2
=
prog
.
add_instruction
(
op
::
broadcast
{
axis
,
l1
->
get_shape
()
.
lens
()
},
args
[
2
]);
return
prog
.
add_instruction
(
op
::
add
{},
l1
,
l2
);
}
return
prog
.
add_instruction
(
op
,
l0
,
args
[
1
]);
...
...
@@ -354,7 +392,9 @@ struct onnx_parser
}
if
(
args
.
size
()
==
2
)
{
literal
s
=
args
[
1
]
->
get_literal
();
auto
s
=
args
[
1
]
->
eval
();
if
(
s
.
empty
())
MIGRAPHX_THROW
(
"Dynamic shape is not supported."
);
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
}
return
prog
.
add_instruction
(
op
,
args
[
0
]);
...
...
@@ -433,7 +473,15 @@ struct onnx_parser
attribute_map
attributes
,
const
std
::
vector
<
instruction_ref
>&
)
{
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
auto
dim_size
=
attributes
.
at
(
"value"
).
t
().
dims_size
();
// if dim_size is 0, it is a scalar
if
(
dim_size
==
0
)
{
migraphx
::
shape
scalar_shape
{
v
.
get_shape
().
type
()};
return
prog
.
add_literal
(
migraphx
::
literal
{
scalar_shape
,
v
.
data
()});
}
return
prog
.
add_literal
(
v
);
}
...
...
@@ -460,29 +508,96 @@ struct onnx_parser
{
transb
=
parse_value
(
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
}
std
::
vector
<
int64_t
>
perm
=
{
1
,
0
};
std
::
vector
<
int64_t
>
perm
(
args
[
0
]
->
get_shape
().
lens
().
size
());
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
int64_t
{
0
});
// swap the last two elements
std
::
swap
(
*
perm
.
rbegin
(),
*
(
perm
.
rbegin
()
+
1
));
auto
l1
=
(
transa
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
0
])
:
args
[
0
];
auto
l2
=
(
transb
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
1
])
:
args
[
1
];
if
(
args
.
size
()
==
3
)
{
if
(
beta
!=
0.
f
)
if
(
beta
!=
0.
f
&&
args
[
2
]
->
get_shape
().
elements
()
>
0
)
{
auto
l3
=
prog
.
add_instruction
(
op
::
dot
{
alpha
},
l1
,
l2
);
a
ut
o
l4
=
args
[
2
]
;
if
(
l4
->
get_shape
().
scalar
())
// ignore args[2] (no C value added to alpha*A*B)
return
l3
;
if
(
beta
!=
1.
f
)
auto
out_lens
=
l1
->
get_shape
().
lens
(
);
o
ut
_lens
.
back
()
=
l2
->
get_shape
().
lens
().
back
()
;
auto
l3
=
args
[
2
];
auto
l3_lens
=
l3
->
get_shape
().
lens
()
;
if
(
!
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
l3_lens
.
begin
(),
l3_lens
.
end
())
)
{
auto
beta_val
=
prog
.
add_literal
(
beta
);
auto
l5
=
prog
.
add_instruction
(
op
::
scalar
{
args
[
2
]
->
get_shape
()},
beta_val
);
l4
=
prog
.
add_instruction
(
op
::
mul
{},
args
[
2
],
l5
);
l3
=
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
args
[
2
]);
}
return
add_broadcastable_binary_op
(
l3
,
l
4
,
op
::
add
{}
);
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l
2
,
l3
);
}
}
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l2
);
}
instruction_ref
parse_matmul
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
auto
l0
=
args
[
0
];
auto
l1
=
args
[
1
];
auto
l0_lens
=
l0
->
get_shape
().
lens
();
auto
l1_lens
=
l1
->
get_shape
().
lens
();
// args[0] is a vector, prepend 1 to the shape
bool
is_a_prepended
=
false
;
if
(
l0_lens
.
size
()
==
1
)
{
is_a_prepended
=
true
;
l0_lens
.
insert
(
l0_lens
.
begin
(),
1
);
l0
=
prog
.
add_instruction
(
op
::
unsqueeze
{{
0
}},
args
[
0
]);
}
bool
is_b_appended
=
false
;
if
(
l1_lens
.
size
()
==
1
)
{
is_b_appended
=
true
;
l1_lens
.
push_back
(
1
);
l1
=
prog
.
add_instruction
(
op
::
unsqueeze
{{
1
}},
args
[
1
]);
}
instruction_ref
bl0
=
l0
;
instruction_ref
bl1
=
l1
;
if
(
!
std
::
equal
(
l0_lens
.
rbegin
()
+
2
,
l0_lens
.
rend
(),
l1_lens
.
rbegin
()
+
2
,
l1_lens
.
rend
()))
{
auto
l0_it
=
l0_lens
.
begin
()
+
l0_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
l0_lens
.
begin
(),
l0_it
);
auto
l1_it
=
l1_lens
.
begin
()
+
l1_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
l1_lens
.
begin
(),
l1_it
);
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
l0_broadcasted_lens
=
output_lens
;
l0_broadcasted_lens
.
insert
(
l0_broadcasted_lens
.
end
(),
l0_it
,
l0_lens
.
end
());
l1_broadcasted_lens
=
output_lens
;
l1_broadcasted_lens
.
insert
(
l1_broadcasted_lens
.
end
(),
l1_it
,
l1_lens
.
end
());
if
(
l0_lens
!=
l0_broadcasted_lens
)
{
bl0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
l0_broadcasted_lens
},
l0
);
}
if
(
l1_lens
!=
l1_broadcasted_lens
)
{
bl1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
l1_broadcasted_lens
},
l1
);
}
}
auto
dot_res
=
prog
.
add_instruction
(
op
::
dot
{
1.0
f
,
0.0
f
},
bl0
,
bl1
);
int64_t
num_axis
=
static_cast
<
int64_t
>
(
dot_res
->
get_shape
().
lens
().
size
());
if
(
is_a_prepended
)
{
dot_res
=
prog
.
add_instruction
(
op
::
squeeze
{{
num_axis
-
2
}},
dot_res
);
--
num_axis
;
}
if
(
is_b_appended
)
{
dot_res
=
prog
.
add_instruction
(
op
::
squeeze
{{
num_axis
-
1
}},
dot_res
);
}
return
dot_res
;
}
instruction_ref
parse_batchnorm
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
@@ -573,15 +688,15 @@ struct onnx_parser
auto
&&
bias_floats
=
attributes
[
"bias"
].
floats
();
bias
=
std
::
vector
<
float
>
(
bias_floats
.
begin
(),
bias_floats
.
end
());
}
auto
input_
shape
=
args
.
front
()
->
get_shape
();
auto
input_
lens
=
args
.
front
()
->
get_shape
()
.
lens
()
;
auto
scale_val
=
prog
.
add_literal
(
scale
);
auto
bias_vals
=
prog
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
bias
.
size
()}},
bias
});
auto
scale_tensor
=
prog
.
add_instruction
(
migraphx
::
op
::
scalar
{
input_
shape
},
scale_val
);
auto
scale_tensor
=
prog
.
add_instruction
(
migraphx
::
op
::
scalar
{
input_
lens
},
scale_val
);
auto
img_scaled
=
prog
.
add_instruction
(
migraphx
::
op
::
mul
{},
args
.
front
(),
scale_tensor
);
auto
bias_bcast
=
prog
.
add_instruction
(
migraphx
::
op
::
broadcast
{
1
,
input_
shape
},
bias_vals
);
auto
bias_bcast
=
prog
.
add_instruction
(
migraphx
::
op
::
broadcast
{
1
,
input_
lens
},
bias_vals
);
return
prog
.
add_instruction
(
migraphx
::
op
::
add
{},
img_scaled
,
bias_bcast
);
}
...
...
@@ -607,6 +722,11 @@ struct onnx_parser
auto
&&
pad_vals
=
attributes
[
"pads"
].
ints
();
pads
=
std
::
vector
<
int64_t
>
(
pad_vals
.
begin
(),
pad_vals
.
end
());
}
// check if padding is actually being done (at least one value is nonzero)
if
(
std
::
all_of
(
pads
.
begin
(),
pads
.
end
(),
[](
const
int
&
i
)
{
return
i
==
0
;
}))
{
return
prog
.
add_instruction
(
migraphx
::
op
::
identity
{},
args
.
front
());
}
if
(
contains
(
attributes
,
"value"
))
{
value
=
parse_value
(
attributes
.
at
(
"value"
)).
at
<
float
>
();
...
...
@@ -749,15 +869,17 @@ struct onnx_parser
{
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
for_each
(
names
.
begin
(),
names
.
end
(),
[
&
](
auto
&
fn
)
{
vec_names
.
push_back
(
fn
);
});
vec_names
.
resize
(
names
.
size
());
std
::
copy
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
());
}
for_each
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
fn
)
{
if
(
map_actv_funcs
.
count
(
fn
)
==
0
)
{
MIGRAPHX_THROW
(
"RNN: activation function "
+
std
::
string
(
fn
)
+
" not supported"
);
}
auto
name_it
=
std
::
find_if
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
return
(
map_actv_funcs
.
count
(
name
)
==
0
);
});
if
(
name_it
!=
vec_names
.
end
())
{
MIGRAPHX_THROW
(
"RNN: activation function "
+
std
::
string
(
*
name_it
)
+
" not supported"
);
}
// bidirectional case should have two activation functions.
// one is for forward, and the other is for reverse.
...
...
@@ -839,8 +961,7 @@ struct onnx_parser
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
vec_names
.
resize
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
(),
[](
auto
&
str
)
{
return
str
;
});
std
::
copy
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
());
}
// need 4 activation functions
...
...
@@ -878,12 +999,13 @@ struct onnx_parser
}
}
for_each
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
if
(
map_actv_funcs
.
count
(
name
)
==
0
)
{
MIGRAPHX_THROW
(
"GRU: activation function "
+
std
::
string
(
name
)
+
" not supported"
);
}
auto
name_it
=
std
::
find_if
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
return
(
map_actv_funcs
.
count
(
name
)
==
0
);
});
if
(
name_it
!=
vec_names
.
end
())
{
MIGRAPHX_THROW
(
"GRU: activation function "
+
std
::
string
(
*
name_it
)
+
" not supported"
);
}
std
::
vector
<
operation
>
vec_actv_funcs
(
vec_names
.
size
());
std
::
transform
(
vec_names
.
begin
(),
vec_names
.
end
(),
vec_actv_funcs
.
begin
(),
[
&
](
auto
&
name
)
{
...
...
@@ -920,6 +1042,178 @@ struct onnx_parser
return
{
hidden_states
,
last_output
};
}
std
::
vector
<
instruction_ref
>
parse_lstm
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
migraphx
::
shape
input_shape
=
args
[
0
]
->
get_shape
();
std
::
size_t
hidden_size
=
args
[
2
]
->
get_shape
().
lens
()[
2
];
if
(
contains
(
attributes
,
"hidden_size"
))
{
std
::
size_t
hidden_size_att
=
parse_value
(
attributes
.
at
(
"hidden_size"
)).
at
<
int
>
();
if
(
hidden_size
!=
hidden_size_att
)
{
MIGRAPHX_THROW
(
"LSTM: hidden size mismatch in input and attribute"
);
}
}
// Handling of direction to be added later
std
::
string
direction
{
"forward"
};
if
(
contains
(
attributes
,
"direction"
))
{
direction
=
attributes
.
at
(
"direction"
).
s
();
}
op
::
rnn_direction
dirct
=
op
::
rnn_direction
::
forward
;
if
(
direction
==
"bidirectional"
)
{
dirct
=
op
::
rnn_direction
::
bidirectional
;
}
else
if
(
direction
==
"reverse"
)
{
dirct
=
op
::
rnn_direction
::
reverse
;
}
else
if
(
direction
==
"forward"
)
{
dirct
=
op
::
rnn_direction
::
forward
;
}
else
{
MIGRAPHX_THROW
(
"LSTM: incorrect direction attribute"
);
}
std
::
vector
<
std
::
string
>
vec_names
=
{
"sigmoid"
,
"tanh"
,
"tanh"
};
if
(
contains
(
attributes
,
"activations"
))
{
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
vec_names
.
resize
(
names
.
size
());
std
::
copy
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
());
}
// need 6 activation functions for bidirectional directions
if
(
dirct
==
op
::
rnn_direction
::
bidirectional
)
{
// 6 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
// use the algorithm that: if 1 actv function is provided,
// repeat 1st six times. If 2 actv functins are provided,
// repeat 2nd once, then repeat all three once
// if 3 actv funcs are provide, repeat all three once.
// the same algorithm is used for 4, 5, and 6 actv funcions
// provided. This may need change later
switch
(
vec_names
.
size
())
{
case
1
:
vec_names
=
{
vec_names
.
at
(
0
),
vec_names
.
at
(
0
),
vec_names
.
at
(
0
),
vec_names
.
at
(
0
),
vec_names
.
at
(
0
),
vec_names
.
at
(
0
)};
break
;
case
2
:
// repeat the 2nd actv func once, then repeat all three another time
vec_names
=
{
vec_names
.
at
(
0
),
vec_names
.
at
(
1
),
vec_names
.
at
(
1
),
vec_names
.
at
(
0
),
vec_names
.
at
(
1
),
vec_names
.
at
(
1
)};
break
;
case
3
:
// repeat all three actv funcs once
vec_names
=
{
vec_names
.
at
(
0
),
vec_names
.
at
(
1
),
vec_names
.
at
(
2
),
vec_names
.
at
(
0
),
vec_names
.
at
(
1
),
vec_names
.
at
(
2
)};
break
;
case
4
:
vec_names
=
{
vec_names
.
at
(
0
),
vec_names
.
at
(
1
),
vec_names
.
at
(
2
),
vec_names
.
at
(
3
),
vec_names
.
at
(
3
),
vec_names
.
at
(
3
)};
break
;
case
5
:
vec_names
=
{
vec_names
.
at
(
0
),
vec_names
.
at
(
1
),
vec_names
.
at
(
2
),
vec_names
.
at
(
3
),
vec_names
.
at
(
4
),
vec_names
.
at
(
4
)};
break
;
default:
break
;
}
}
else
{
switch
(
vec_names
.
size
())
{
case
1
:
vec_names
=
{
vec_names
.
at
(
0
),
vec_names
.
at
(
0
),
vec_names
.
at
(
0
)};
break
;
case
2
:
// repeat the 2nd actv func once, so we have 3 actv funcs
vec_names
=
{
vec_names
.
at
(
0
),
vec_names
.
at
(
1
),
vec_names
.
at
(
1
)};
break
;
default:
break
;
}
}
auto
name_it
=
std
::
find_if
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
return
(
map_actv_funcs
.
count
(
name
)
==
0
);
});
if
(
name_it
!=
vec_names
.
end
())
{
MIGRAPHX_THROW
(
"LSTM: activation function "
+
std
::
string
(
*
name_it
)
+
" not supported"
);
}
std
::
vector
<
operation
>
vec_actv_funcs
(
vec_names
.
size
());
std
::
transform
(
vec_names
.
begin
(),
vec_names
.
end
(),
vec_actv_funcs
.
begin
(),
[
&
](
auto
&
name
)
{
return
map_actv_funcs
[
name
];
});
float
clip
=
0.0
;
if
(
contains
(
attributes
,
"clip"
))
{
clip
=
parse_value
(
attributes
.
at
(
"clip"
)).
at
<
float
>
();
}
int
input_forget
=
0
;
if
(
contains
(
attributes
,
"input_forget"
))
{
input_forget
=
parse_value
(
attributes
.
at
(
"input_forget"
)).
at
<
int
>
();
}
// append undefined opeator to make 6 arguments
if
(
args
.
size
()
<
8
)
{
auto
ins
=
prog
.
add_instruction
(
op
::
undefined
{});
args
.
insert
(
args
.
end
(),
8
-
args
.
size
(),
ins
);
}
// first output for concatenation of hidden states
auto
hidden_states
=
prog
.
add_instruction
(
op
::
lstm
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
,
input_forget
},
std
::
move
(
args
));
// second output for last lstm output
auto
last_output
=
prog
.
add_instruction
(
op
::
rnn_last_output
{},
hidden_states
);
// third output for last cell output
auto
last_cell_output
=
prog
.
add_instruction
(
op
::
lstm_last_cell_output
{},
hidden_states
);
return
{
hidden_states
,
last_output
,
last_cell_output
};
}
void
parse_from
(
std
::
istream
&
is
)
{
onnx
::
ModelProto
model
;
...
...
@@ -960,9 +1254,9 @@ struct onnx_parser
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
}
}
for
(
auto
&&
p
:
nodes
)
for
(
auto
&&
output
:
graph
.
output
()
)
{
this
->
parse_node
(
p
.
first
);
this
->
parse_node
(
output
.
name
()
);
}
}
...
...
@@ -996,7 +1290,7 @@ struct onnx_parser
std
::
vector
<
instruction_ref
>
result
;
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
result
.
push_back
(
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
));
result
.
push_back
(
prog
.
add_instruction
(
op
::
unknown
{
node
.
op_type
()},
args
));
}
else
{
...
...
@@ -1084,28 +1378,26 @@ struct onnx_parser
static
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if
(
dims
.
empty
())
{
dims
=
{
1
};
}
if
(
t
.
has_raw_data
())
{
const
std
::
string
&
s
=
t
.
raw_data
();
switch
(
t
.
data_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
}
,
s
.
data
()
}
;
case
onnx
::
TensorProto
::
FLOAT
:
return
create_
literal
(
shape
::
float_type
,
dims
,
s
.
data
()
)
;
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
UINT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT8
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
UINT16
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT16
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT32
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT64
:
return
create_literal
(
shape
::
int64_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
FLOAT16
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
FLOAT16
:
return
create_literal
(
shape
::
half_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
DOUBLE
:
return
create_literal
(
shape
::
double_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
...
...
@@ -1117,26 +1409,33 @@ struct onnx_parser
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
}
,
t
.
float_data
()
.
begin
(),
t
.
float_data
().
end
()}
;
return
create_
literal
(
shape
::
float_type
,
dims
,
t
.
float_data
()
)
;
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
UINT16
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
INT16
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
INT32
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
INT64
:
return
literal
{{
shape
::
int64_type
,
dims
}
,
t
.
int64_data
()
.
begin
(),
t
.
int64_data
().
end
()}
;
return
create_
literal
(
shape
::
int64_type
,
dims
,
t
.
int64_data
()
)
;
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
FLOAT16
:
return
literal
{{
shape
::
half_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
{
std
::
vector
<
uint16_t
>
data_uint16
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
std
::
vector
<
half
>
data_half
;
std
::
transform
(
data_uint16
.
begin
(),
data_uint16
.
end
(),
std
::
back_inserter
(
data_half
),
[](
uint16_t
raw_val
)
{
return
*
reinterpret_cast
<
half
*>
(
&
raw_val
);
});
return
create_literal
(
shape
::
half_type
,
dims
,
data_half
);
}
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{
{
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
...
...
@@ -1145,6 +1444,23 @@ struct onnx_parser
MIGRAPHX_THROW
(
"Invalid tensor type"
);
}
static
literal
create_literal
(
shape
::
type_t
shape_type
,
const
std
::
vector
<
size_t
>&
dims
,
const
char
*
data
)
{
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if
(
dims
.
empty
())
return
literal
{{
shape_type
},
data
};
return
literal
{{
shape_type
,
dims
},
data
};
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
std
::
is_pointer
<
T
>{})
>
static
literal
create_literal
(
shape
::
type_t
shape_type
,
const
std
::
vector
<
size_t
>&
dims
,
T
data
)
{
if
(
dims
.
empty
())
return
literal
{{
shape_type
},
data
.
begin
(),
data
.
end
()};
return
literal
{{
shape_type
,
dims
},
data
.
begin
(),
data
.
end
()};
}
static
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
{
shape
::
type_t
shape_type
{};
...
...
src/opt/common_header.hpp
deleted
100644 → 0
View file @
65ef35cd
#ifndef MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
//#define MIGRAPHX_DEBUG_OPT
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
src/opt/memory_coloring_impl.cpp
View file @
eb0d8fee
#include <migraphx/op/load.hpp>
#include "memory_coloring_impl.hpp"
namespace
migraphx
{
...
...
@@ -62,11 +63,11 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
}
}
long
long
offset
=
0
;
std
::
size_t
offset
=
0
;
while
(
!
conflict_queue
.
empty
())
{
live_range
*
range
=
conflict_queue
.
top
();
long
long
iter_offset
=
range
->
offset
;
live_range
*
range
=
conflict_queue
.
top
();
std
::
size_t
iter_offset
=
range
->
offset
;
if
(
offset
>
iter_offset
)
{
offset
=
std
::
max
(
offset
,
iter_offset
+
range
->
size
);
...
...
@@ -96,7 +97,7 @@ void memory_coloring_impl::build()
if
(
num_of_instrs
==
0
)
return
;
int
cur_points
=
num_of_instrs
*
2
;
auto
cur_points
=
num_of_instrs
*
2
;
instruction_ref
iter
=
p_program
->
end
();
instruction_ref
begin
=
p_program
->
begin
();
std
::
vector
<
instruction_ref
>
dead_instrs
;
...
...
@@ -176,7 +177,7 @@ void memory_coloring_impl::build()
void
memory_coloring_impl
::
rewrite
()
{
std
::
vector
<
std
::
size_t
>
dims
;
dims
.
push_back
(
required_bytes
/
sizeof
(
float
));
dims
.
push_back
(
(
required_bytes
+
sizeof
(
float
)
-
1
)
/
sizeof
(
float
));
shape
s
=
{
shape
::
float_type
,
dims
};
instruction_ref
scratch_param
=
p_program
->
add_parameter
(
"scratch"
,
s
);
for
(
auto
ins
:
iterator_for
(
*
p_program
))
...
...
@@ -192,29 +193,19 @@ void memory_coloring_impl::rewrite()
continue
;
std
::
size_t
offset
=
0
;
if
(
interval
->
get_offset
()
=
=
invalid_offset
)
if
(
interval
->
get_offset
()
!
=
invalid_offset
)
{
assert
(
interval
->
result
.
bytes
()
==
0
);
offset
=
interval
->
get_offset
(
);
}
else
{
offset
=
interval
->
get_offset
(
);
assert
(
interval
->
result
.
bytes
()
==
0
);
}
if
(
is_allocate
(
ins
))
{
assert
(
!
ins
->
inputs
().
empty
());
p_program
->
replace_instruction
(
ins
,
op
::
load
{
ins
->
inputs
().
at
(
0
)
->
get_shape
(),
offset
},
scratch_param
);
}
else
if
(
is_literal
(
ins
))
{
#if 0
auto pre = p_program->add_literal(ins->lit);
bool pre_copy = (interval->get_begin() < earliest_end_point);
p_program->replace_instruction(
ins, write_literal{offset, pre_copy}, scratch_param, pre);
#endif
ins
,
op
::
load
{
ins
->
get_shape
(),
offset
},
scratch_param
);
}
}
}
...
...
src/opt/memory_coloring_impl.hpp
View file @
eb0d8fee
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include "common_header.hpp"
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
static
const
in
t
invalid_offset
=
-
1
;
static
const
std
::
size_
t
invalid_offset
=
std
::
numeric_limits
<
std
::
size_t
>::
max
()
;
struct
live_range
{
in
t
begin
;
// begin point in the instruction stream.
int
end
;
// end point in the instruction stream.
long
long
offset
;
// offset to base pointer of allocated memory trunk.
int
vn
;
// value number that identifies this live_range.
long
long
size
;
// size of required memory in bytes
std
::
size_
t
begin
;
// begin point in the instruction stream.
std
::
size_t
end
;
// end point in the instruction stream.
std
::
size_t
offset
;
// offset to base pointer of allocated memory trunk.
std
::
size_t
vn
;
// value number that identifies this live_range.
std
::
size_t
size
;
// size of required memory in bytes
#ifdef MIGRAPHX_DEBUG_OPT
void
dump
();
#endif
...
...
@@ -30,9 +45,9 @@ struct live_interval
is_live_on_entry
=
false
;
}
void
add_use
(
in
t
use
)
{
use_points
.
push_front
(
use
);
}
in
t
get_begin
()
const
{
return
segment
.
begin
;
}
in
t
get_end
()
const
{
return
segment
.
end
;
}
void
add_use
(
std
::
size_
t
use
)
{
use_points
.
push_front
(
use
);
}
std
::
size_
t
get_begin
()
const
{
return
segment
.
begin
;
}
std
::
size_
t
get_end
()
const
{
return
segment
.
end
;
}
long
long
get_offset
()
const
{
return
segment
.
offset
;
}
#ifdef MIGRAPHX_DEBUG_OPT
...
...
@@ -40,9 +55,9 @@ struct live_interval
#endif
live_range
segment
;
in
t
id
;
std
::
list
<
in
t
>
use_points
;
in
t
def_point
;
std
::
size_
t
id
;
std
::
list
<
std
::
size_
t
>
use_points
;
std
::
size_
t
def_point
;
shape
result
;
bool
is_literal
;
bool
is_live_on_entry
;
...
...
@@ -96,8 +111,8 @@ struct memory_coloring_impl
{
if
((
range1
.
size
==
0
)
||
(
range2
.
size
==
0
))
return
false
;
long
long
end1
=
range1
.
offset
+
range1
.
size
-
1
;
long
long
end2
=
range2
.
offset
+
range2
.
size
-
1
;
auto
end1
=
range1
.
offset
+
range1
.
size
-
1
;
auto
end2
=
range2
.
offset
+
range2
.
size
-
1
;
return
((
end1
<
range2
.
offset
)
||
(
end2
<
range1
.
offset
));
}
void
verify
();
...
...
@@ -110,8 +125,8 @@ struct memory_coloring_impl
{
bool
operator
()(
const
interval_ptr
i1
,
const
interval_ptr
i2
)
const
{
int
len1
=
i1
->
get_end
()
-
i1
->
get_begin
();
int
len2
=
i2
->
get_end
()
-
i2
->
get_begin
();
auto
len1
=
i1
->
get_end
()
-
i1
->
get_begin
();
auto
len2
=
i2
->
get_end
()
-
i2
->
get_begin
();
if
(
len1
!=
len2
)
{
return
(
len1
<
len2
);
...
...
@@ -143,7 +158,7 @@ struct memory_coloring_impl
int
num_of_lives
;
int
max_value_number
;
long
long
required_bytes
;
std
::
size_t
required_bytes
;
// The earliest program point where an live interval ends.
int
earliest_end_point
;
// The latest program point where an live interval ends.
...
...
src/pass_manager.cpp
0 → 100644
View file @
eb0d8fee
#include <migraphx/program.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
)
{
for
(
auto
&
p
:
passes
)
{
trace
(
"Pass: "
,
p
.
name
());
p
.
apply
(
prog
);
trace
(
prog
);
#ifndef NDEBUG
trace
(
"Validate ..."
);
auto
invalid
=
prog
.
validate
();
if
(
invalid
!=
prog
.
end
())
{
auto
index
=
std
::
distance
(
prog
.
begin
(),
invalid
);
MIGRAPHX_THROW
(
p
.
name
()
+
" pass produces invalid program at instruction "
+
std
::
to_string
(
index
)
+
": "
+
invalid
->
name
());
}
trace
();
#endif
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/program.cpp
View file @
eb0d8fee
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_manager.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
...
...
@@ -14,9 +16,6 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_COMPILE
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_EVAL
)
struct
program_impl
{
// A list is used to keep references to an instruction stable
...
...
@@ -57,18 +56,23 @@ static void print_instruction(std::ostream& os,
}
template
<
class
F
>
static
void
print_program
(
std
::
ostream
&
os
,
const
program
&
p
,
F
annonate
)
static
void
print_program
(
const
program
&
p
,
F
print_func
)
{
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
int
count
=
0
;
for
(
auto
ins
:
iterator_for
(
p
))
{
std
::
string
var_name
=
"@"
+
std
::
to_string
(
count
)
;
std
::
string
var_name
;
if
(
ins
->
name
()
==
"@param"
)
{
var_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
}
else
{
var_name
=
"@"
+
std
::
to_string
(
count
);
count
++
;
}
names
.
emplace
(
ins
,
var_name
);
// TODO: Use all_of
...
...
@@ -78,21 +82,77 @@ static void print_program(std::ostream& os, const program& p, F annonate)
(
void
)
arg
;
}
print_instruction
(
os
,
ins
,
names
);
annonate
(
ins
,
names
);
os
<<
std
::
endl
;
count
++
;
print_func
(
ins
,
names
);
}
}
program
::
program
()
:
impl
(
std
::
make_unique
<
program_impl
>
())
{}
program
::
program
(
program
&&
)
noexcept
=
default
;
program
&
program
::
operator
=
(
program
&&
)
noexcept
=
default
;
program
::~
program
()
noexcept
=
default
;
program
::~
program
()
noexcept
=
default
;
// copy constructor
program
::
program
(
const
program
&
p
)
{
assign
(
p
);
}
// copy assignment operator
program
&
program
::
operator
=
(
program
p
)
{
std
::
swap
(
p
.
impl
,
this
->
impl
);
return
*
this
;
}
void
program
::
assign
(
const
program
&
p
)
{
// clean the current program
if
(
!
impl
)
{
impl
=
std
::
make_unique
<
program_impl
>
();
}
else
if
(
!
impl
->
instructions
.
empty
())
{
impl
->
instructions
.
clear
();
}
impl
->
ctx
=
p
.
impl
->
ctx
;
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
for
(
auto
ins
:
iterator_for
(
p
))
{
instruction_ref
copy_ins
{};
if
(
ins
->
name
()
==
"@literal"
)
{
auto
l
=
ins
->
get_literal
();
copy_ins
=
impl
->
instructions
.
insert
(
impl
->
instructions
.
end
(),
instruction
{
l
});
}
else
if
(
ins
->
name
()
==
"@param"
)
{
auto
&&
name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
auto
s
=
ins
->
get_shape
();
copy_ins
=
impl
->
instructions
.
insert
(
impl
->
instructions
.
end
(),
{
builtin
::
param
{
name
},
std
::
move
(
s
),
{}});
}
else
if
(
ins
->
name
()
==
"@outline"
)
{
auto
s
=
ins
->
get_shape
();
copy_ins
=
impl
->
instructions
.
insert
(
impl
->
instructions
.
end
(),
{
builtin
::
outline
{
s
},
s
,
{}});
}
else
{
// retrieve its mapped input
auto
inputs
=
ins
->
inputs
();
// ensure all inputs have its corresponding copy instructions
assert
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
i
)
{
return
ins_map
.
count
(
i
)
>
0
;
}));
std
::
vector
<
instruction_ref
>
copy_inputs
(
inputs
.
size
());
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
copy_inputs
.
begin
(),
[
&
](
auto
i
)
{
return
ins_map
[
i
];
});
copy_ins
=
add_instruction
(
ins
->
get_operator
(),
copy_inputs
);
}
ins_map
[
ins
]
=
copy_ins
;
}
}
instruction_ref
program
::
add_instruction
(
const
operation
&
op
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
@@ -106,11 +166,9 @@ instruction_ref program::insert_instruction(instruction_ref ins,
args
.
begin
(),
args
.
end
(),
[
&
](
instruction_ref
x
)
{
return
has_instruction
(
x
);
})
&&
"Argument is not an exisiting instruction"
);
assert
(
not
starts_with
(
op
.
name
(),
"@"
));
// TODO: Use move
shape
r
=
compute_shape
(
op
,
args
);
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
std
::
move
(
args
)});
instruction
::
backreference
(
result
);
// assert(result->inputs() == args);
assert
(
result
->
valid
(
begin
()));
return
result
;
}
...
...
@@ -295,23 +353,7 @@ void program::compile(const target& t, tracer trace)
trace
=
tracer
{
std
::
cout
};
trace
(
*
this
);
trace
();
for
(
auto
&&
p
:
t
.
get_passes
(
this
->
impl
->
ctx
))
{
trace
(
"Pass: "
,
p
.
name
());
p
.
apply
(
*
this
);
trace
(
*
this
);
#ifndef NDEBUG
trace
(
"Validate ..."
);
auto
invalid
=
this
->
validate
();
if
(
invalid
!=
impl
->
instructions
.
end
())
{
auto
index
=
std
::
distance
(
impl
->
instructions
.
begin
(),
invalid
);
MIGRAPHX_THROW
(
p
.
name
()
+
" pass produces invalid program at instruction "
+
std
::
to_string
(
index
)
+
": "
+
invalid
->
name
());
}
trace
();
#endif
}
run_passes
(
*
this
,
t
.
get_passes
(
this
->
impl
->
ctx
),
trace
);
auto
invalid
=
this
->
validate
();
if
(
invalid
!=
impl
->
instructions
.
end
())
{
...
...
@@ -348,13 +390,17 @@ argument generic_eval(const program& p,
}
else
if
(
ins
->
name
()
==
"@param"
)
{
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
auto
param_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
if
(
not
contains
(
params
,
param_name
))
MIGRAPHX_THROW
(
"Parameter not found: "
+
param_name
);
return
params
.
at
(
param_name
);
}));
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
auto
param_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
if
(
not
contains
(
params
,
param_name
))
MIGRAPHX_THROW
(
"Parameter not found: "
+
param_name
);
auto
param
=
params
.
at
(
param_name
);
if
(
param
.
get_shape
()
!=
ins
->
get_shape
())
MIGRAPHX_THROW
(
"Incorrect shape {"
+
to_string
(
param
.
get_shape
())
+
"} for parameter: "
+
param_name
);
return
param
;
}));
}
else
if
(
ins
->
name
()
==
"@outline"
)
{
...
...
@@ -391,13 +437,20 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
#else
auto
check_context
=
[](
auto
f
)
{
return
f
();
};
#endif
if
(
enabled
(
MIGRAPHX_TRACE_EVAL
{}))
auto
trace_level
=
value_of
(
MIGRAPHX_TRACE_EVAL
{});
if
(
trace_level
>
0
)
{
return
generic_eval
(
*
this
,
ctx
,
std
::
move
(
params
),
[
&
](
auto
&
ins
,
auto
f
)
{
ctx
.
finish
();
std
::
cout
<<
"Run instruction: "
;
this
->
debug_print
(
ins
);
return
check_context
(
f
);
auto
result
=
check_context
(
f
);
ctx
.
finish
();
if
(
trace_level
>
1
and
ins
->
name
().
front
()
!=
'@'
and
ins
->
name
()
!=
"load"
)
std
::
cout
<<
"Ouput: "
<<
result
<<
std
::
endl
;
return
result
;
});
}
else
...
...
@@ -475,10 +528,12 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
double
calculate_overhead_time
=
total_time
-
total_instruction_time
;
double
calculate_overhead_percent
=
calculate_overhead_time
*
100.0
/
total_time
;
print_program
(
os
,
*
this
,
[
&
](
auto
ins
,
auto
&&
)
{
print_program
(
*
this
,
[
&
](
auto
ins
,
const
auto
&
names
)
{
print_instruction
(
std
::
cout
,
ins
,
names
);
double
avg
=
common_average
(
ins_vec
[
ins
]);
double
percent
=
std
::
ceil
(
100.0
*
avg
/
total_instruction_time
);
os
<<
": "
<<
avg
<<
"ms, "
<<
percent
<<
"%"
;
os
<<
std
::
endl
;
});
os
<<
std
::
endl
;
...
...
@@ -505,8 +560,18 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
void
program
::
debug_print
()
const
{
std
::
cout
<<
*
this
<<
std
::
endl
;
}
void
program
::
debug_print
(
instruction_ref
ins
)
const
{
if
(
ins
==
this
->
end
())
{
std
::
cout
<<
"End instruction"
<<
std
::
endl
;
return
;
}
if
(
not
has_instruction
(
ins
))
{
std
::
cout
<<
"Instruction not part of program"
<<
std
::
endl
;
return
;
}
std
::
stringstream
ss
;
print_program
(
ss
,
*
this
,
[
&
](
auto
x
,
auto
&
&
names
)
{
print_program
(
*
this
,
[
&
](
auto
x
,
const
auto
&
names
)
{
if
(
x
==
ins
)
{
print_instruction
(
std
::
cout
,
x
,
names
);
...
...
@@ -521,17 +586,55 @@ void program::debug_print(const std::vector<instruction_ref>& inss) const
std
::
cout
<<
std
::
endl
;
}
static
std
::
string
enclose_name
(
const
std
::
string
&
name
)
{
return
'"'
+
replace_string
(
name
,
"
\"
"
,
"
\\\"
"
)
+
'"'
;
}
void
program
::
print_graph
(
std
::
ostream
&
os
)
const
{
os
<<
"digraph {"
<<
std
::
endl
;
os
<<
"
\t
rankdir=LR;"
<<
std
::
endl
;
print_program
(
*
this
,
[
&
](
auto
ins
,
const
auto
&
names
)
{
os
<<
"
\t
"
<<
enclose_name
(
names
.
at
(
ins
))
<<
"[label="
<<
enclose_name
(
to_string
(
ins
->
get_operator
()))
<<
"];"
;
os
<<
std
::
endl
;
if
(
!
ins
->
inputs
().
empty
())
{
for
(
auto
&&
arg
:
ins
->
inputs
())
{
os
<<
"
\t
"
<<
enclose_name
(
names
.
at
(
arg
))
<<
" -> "
<<
enclose_name
(
names
.
at
(
ins
));
os
<<
"[label="
<<
enclose_name
(
to_string
(
ins
->
get_shape
()))
<<
"];"
;
os
<<
std
::
endl
;
}
}
});
os
<<
"}"
<<
std
::
endl
;
}
void
program
::
dry_run
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
{
auto
&
ctx
=
this
->
impl
->
ctx
;
generic_eval
(
*
this
,
ctx
,
std
::
move
(
params
),
[](
auto
&&
...)
{
return
argument
{};
});
}
void
program
::
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
{
print_program
(
*
this
,
[
&
](
auto
ins
,
const
auto
&
names
)
{
print_instruction
(
os
,
ins
,
names
);
a
(
ins
);
os
<<
std
::
endl
;
});
}
bool
operator
==
(
const
program
&
x
,
const
program
&
y
)
{
return
to_string
(
x
)
==
to_string
(
y
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
)
{
print_program
(
os
,
p
,
[](
auto
&&
...)
{});
print_program
(
p
,
[
&
](
auto
ins
,
const
auto
&
names
)
{
print_instruction
(
os
,
ins
,
names
);
os
<<
std
::
endl
;
});
return
os
;
}
...
...
src/propagate_constant.cpp
0 → 100644
View file @
eb0d8fee
#include <migraphx/propagate_constant.hpp>
#include <migraphx/program.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <unordered_set>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
bool
skip_propogate
(
instruction_ref
ins
)
{
if
(
ins
->
name
()
==
"@literal"
)
return
true
;
auto
&&
s
=
ins
->
get_shape
();
if
(
s
.
broadcasted
()
and
not
s
.
scalar
())
return
true
;
if
(
s
.
scalar
()
and
s
.
elements
()
!=
1
)
return
true
;
return
false
;
}
void
propagate_constant
::
apply
(
program
&
p
)
const
{
for
(
auto
i
:
iterator_for
(
p
))
{
if
(
i
->
name
()
!=
"@literal"
)
continue
;
if
(
i
->
outputs
().
empty
())
continue
;
fix
([
&
](
auto
self
,
auto
ins
)
{
std
::
unordered_set
<
instruction_ref
>
children
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
());
for
(
auto
child
:
children
)
{
if
(
skip_propogate
(
child
))
{
self
(
child
);
continue
;
}
auto
r
=
child
->
eval
();
if
(
not
r
.
empty
())
{
assert
(
r
.
get_shape
()
==
child
->
get_shape
());
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
self
(
p
.
replace_instruction
(
child
,
l
));
}
}
})(
i
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/py/CMakeLists.txt
View file @
eb0d8fee
option
(
MIGRAPHX_ENABLE_PYTHON
"Enable python bindings"
ON
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
find_program
(
DEFAULT_PYTHON_EXE python
)
if
(
DEFAULT_PYTHON_EXE
)
set
(
PYTHON_EXECUTABLE
${
DEFAULT_PYTHON_EXE
}
CACHE PATH
"Path to python executable"
)
endif
()
find_package
(
pybind11 REQUIRED
)
pybind11_add_module
(
migraphx_py migraphx_py.cpp
)
set_target_properties
(
migraphx_py PROPERTIES
OUTPUT_NAME migraphx
C_VISIBILITY_PRESET hidden
CXX_VISIBILITY_PRESET hidden
)
target_link_libraries
(
migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu
)
if
(
MIGRAPHX_ENABLE_GPU
)
target_link_libraries
(
migraphx_py PRIVATE migraphx_gpu
)
target_compile_definitions
(
migraphx_py PRIVATE -DHAVE_GPU
)
endif
()
find_program
(
DEFAULT_PYTHON_EXE python
)
if
(
DEFAULT_PYTHON_EXE
)
set
(
PYTHON_EXECUTABLE
${
DEFAULT_PYTHON_EXE
}
CACHE PATH
"Path to python executable"
)
endif
()
find_package
(
pybind11 REQUIRED
)
pybind11_add_module
(
migraphx_py migraphx_py.cpp
)
set_target_properties
(
migraphx_py PROPERTIES
OUTPUT_NAME migraphx
C_VISIBILITY_PRESET hidden
CXX_VISIBILITY_PRESET hidden
)
if
(
MIGRAPHX_ENABLE_TF
)
target_link_libraries
(
migraphx_py PRIVATE migraphx migraphx_tf migraphx_cpu
)
target_compile_definitions
(
migraphx_py PRIVATE -DENABLE_TF
)
else
()
target_link_libraries
(
migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu
)
endif
()
if
(
MIGRAPHX_ENABLE_GPU
)
target_link_libraries
(
migraphx_py PRIVATE migraphx_gpu
)
target_compile_definitions
(
migraphx_py PRIVATE -DHAVE_GPU
)
endif
()
rocm_install_targets
(
TARGETS migraphx_py
)
endif
()
Prev
1
2
3
4
5
6
7
8
9
10
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment