Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
583c76f2
Commit
583c76f2
authored
Jul 10, 2018
by
Scott Thornton
Browse files
Merge branch 'master' into mnist2
parents
06fb0905
603adbe6
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
370 additions
and
16 deletions
+370
-16
src/include/migraph/iterator_for.hpp
src/include/migraph/iterator_for.hpp
+31
-0
src/include/migraph/pass.hpp
src/include/migraph/pass.hpp
+204
-0
src/include/migraph/target.hpp
src/include/migraph/target.hpp
+13
-7
src/program.cpp
src/program.cpp
+8
-1
src/targets/cpu/cpu_target.cpp
src/targets/cpu/cpu_target.cpp
+10
-2
src/targets/cpu/include/migraph/cpu/cpu_target.hpp
src/targets/cpu/include/migraph/cpu/cpu_target.hpp
+1
-1
src/targets/miopen/include/migraph/miopen/miopen_target.hpp
src/targets/miopen/include/migraph/miopen/miopen_target.hpp
+1
-1
src/targets/miopen/miopen_target.cpp
src/targets/miopen/miopen_target.cpp
+9
-2
test/eval_test.cpp
test/eval_test.cpp
+67
-1
tools/include/pass.hpp
tools/include/pass.hpp
+23
-0
tools/include/target.hpp
tools/include/target.hpp
+3
-1
No files found.
src/include/migraph/iterator_for.hpp
0 → 100644
View file @
583c76f2
#ifndef MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#define MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
namespace
migraph
{
template
<
class
T
>
struct
iterator_for_range
{
T
*
base
;
using
base_iterator
=
decltype
(
base
->
begin
());
struct
iterator
{
base_iterator
i
;
base_iterator
operator
*
()
{
return
i
;
}
base_iterator
operator
++
()
{
return
++
i
;
}
bool
operator
!=
(
const
iterator
&
rhs
)
{
return
i
!=
rhs
.
i
;
}
};
iterator
begin
()
{
return
{
base
->
begin
()};
}
iterator
end
()
{
return
{
base
->
end
()};
}
};
template
<
class
T
>
iterator_for_range
<
T
>
iterator_for
(
T
&
x
)
{
return
{
&
x
};
}
}
// namespace migraph
#endif
src/include/migraph/pass.hpp
0 → 100644
View file @
583c76f2
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace
migraph
{
struct
program
;
/*
* Type-erased interface for:
*
* struct pass
* {
* std::string name() const;
* void apply(program & p) const;
* };
*
*/
struct
pass
{
// Constructors
pass
()
=
default
;
template
<
typename
PrivateDetailTypeErasedT
>
pass
(
PrivateDetailTypeErasedT
value
)
:
private_detail_te_handle_mem_var
(
std
::
make_shared
<
private_detail_te_handle_type
<
typename
std
::
remove_reference
<
PrivateDetailTypeErasedT
>::
type
>>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
)))
{
}
// Assignment
template
<
typename
PrivateDetailTypeErasedT
>
pass
&
operator
=
(
PrivateDetailTypeErasedT
value
)
{
if
(
private_detail_te_handle_mem_var
.
unique
())
*
private_detail_te_handle_mem_var
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
else
if
(
!
private_detail_te_handle_mem_var
)
private_detail_te_handle_mem_var
=
std
::
make_shared
<
PrivateDetailTypeErasedT
>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
));
return
*
this
;
}
// Cast
template
<
typename
PrivateDetailTypeErasedT
>
PrivateDetailTypeErasedT
*
any_cast
()
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
.
private_detail_te_value
)
:
nullptr
;
}
template
<
typename
PrivateDetailTypeErasedT
>
const
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
*
any_cast
()
const
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
const
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
.
private_detail_te_value
)
:
nullptr
;
}
const
std
::
type_info
&
type_id
()
const
{
if
(
private_detail_te_handle_empty
())
return
typeid
(
std
::
nullptr_t
);
else
return
private_detail_te_get_handle
().
type
();
}
std
::
string
name
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
void
apply
(
program
&
p
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
}
private:
struct
private_detail_te_handle_base_type
{
virtual
~
private_detail_te_handle_base_type
()
{}
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
void
apply
(
program
&
p
)
const
=
0
;
};
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
{
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
>::
type
*
=
nullptr
)
:
private_detail_te_value
(
value
)
{
}
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
}
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
override
{
return
std
::
make_shared
<
private_detail_te_handle_type
>
(
private_detail_te_value
);
}
const
std
::
type_info
&
type
()
const
override
{
return
typeid
(
private_detail_te_value
);
}
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
void
apply
(
program
&
p
)
const
override
{
return
private_detail_te_value
.
apply
(
p
);
}
PrivateDetailTypeErasedT
private_detail_te_value
;
};
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
<
std
::
reference_wrapper
<
PrivateDetailTypeErasedT
>>
:
private_detail_te_handle_type
<
PrivateDetailTypeErasedT
&>
{
private_detail_te_handle_type
(
std
::
reference_wrapper
<
PrivateDetailTypeErasedT
>
ref
)
:
private_detail_te_handle_type
<
PrivateDetailTypeErasedT
&>
(
ref
.
get
())
{
}
};
bool
private_detail_te_handle_empty
()
const
{
return
private_detail_te_handle_mem_var
==
nullptr
;
}
const
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
const
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
return
*
private_detail_te_handle_mem_var
;
}
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
private_detail_te_handle_mem_var
;
};
template
<
typename
ValueType
>
inline
const
ValueType
*
any_cast
(
const
pass
*
x
)
{
return
x
->
any_cast
<
ValueType
>
();
}
template
<
typename
ValueType
>
inline
ValueType
*
any_cast
(
pass
*
x
)
{
return
x
->
any_cast
<
ValueType
>
();
}
template
<
typename
ValueType
>
inline
ValueType
&
any_cast
(
pass
&
x
)
{
auto
*
y
=
x
.
any_cast
<
typename
std
::
remove_reference
<
ValueType
>::
type
>
();
if
(
y
==
nullptr
)
throw
std
::
bad_cast
();
return
*
y
;
}
template
<
typename
ValueType
>
inline
const
ValueType
&
any_cast
(
const
pass
&
x
)
{
const
auto
*
y
=
x
.
any_cast
<
typename
std
::
remove_reference
<
ValueType
>::
type
>
();
if
(
y
==
nullptr
)
throw
std
::
bad_cast
();
return
*
y
;
}
}
// namespace migraph
#endif
src/include/migraph/target.hpp
View file @
583c76f2
...
...
@@ -6,7 +6,9 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include <migraph/context.hpp>
#include <migraph/pass.hpp>
namespace
migraph
{
...
...
@@ -18,7 +20,7 @@ struct program;
* struct target
* {
* std::string name() const;
*
void apply(program & p
) const;
*
std::vector<pass> get_passes(context& ctx
) const;
* context get_context() const;
* };
*
...
...
@@ -87,10 +89,10 @@ struct target
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
void
apply
(
program
&
p
)
const
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
return
(
*
this
).
private_detail_te_get_handle
().
get_passes
(
ctx
);
}
context
get_context
()
const
...
...
@@ -107,7 +109,7 @@ struct target
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
void
apply
(
program
&
p
)
const
=
0
;
virtual
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
)
const
=
0
;
virtual
context
get_context
()
const
=
0
;
};
...
...
@@ -141,7 +143,11 @@ struct target
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
void
apply
(
program
&
p
)
const
override
{
return
private_detail_te_value
.
apply
(
p
);
}
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
)
const
override
{
return
private_detail_te_value
.
get_passes
(
ctx
);
}
context
get_context
()
const
override
{
return
private_detail_te_value
.
get_context
();
}
...
...
src/program.cpp
View file @
583c76f2
...
...
@@ -111,7 +111,14 @@ void program::compile(const target& t)
{
assert
(
this
->
validate
()
!=
impl
->
instructions
.
end
());
this
->
impl
->
ctx
=
t
.
get_context
();
t
.
apply
(
*
this
);
for
(
auto
&&
p
:
t
.
get_passes
(
this
->
impl
->
ctx
))
{
p
.
apply
(
*
this
);
#ifndef NDEBUG
if
(
this
->
validate
()
==
impl
->
instructions
.
end
())
MIGRAPH_THROW
(
p
.
name
()
+
" pass produces invalid program"
);
#endif
}
if
(
this
->
validate
()
==
impl
->
instructions
.
end
())
MIGRAPH_THROW
(
"Invalid program from compilation"
);
}
...
...
src/targets/cpu/cpu_target.cpp
View file @
583c76f2
...
...
@@ -4,6 +4,7 @@
#include <migraph/dfor.hpp>
#include <migraph/operators.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/iterator_for.hpp>
namespace
migraph
{
namespace
cpu
{
...
...
@@ -491,7 +492,7 @@ struct cpu_apply
void
apply
()
{
init
();
for
(
auto
it
=
prog
->
begin
();
it
!=
prog
->
end
();
it
++
)
for
(
auto
it
:
iterator_for
(
*
prog
)
)
{
if
(
it
->
op
.
name
()
==
"activation"
)
{
...
...
@@ -538,9 +539,16 @@ struct cpu_apply
}
};
struct
cpu_pass
{
std
::
string
name
()
const
{
return
"cpu::pass"
;
}
void
apply
(
program
&
p
)
const
{
cpu_apply
{
&
p
}.
apply
();
}
};
std
::
string
cpu_target
::
name
()
const
{
return
"cpu"
;
}
void
cpu_target
::
apply
(
program
&
p
)
const
{
cpu_apply
{
&
p
}.
apply
()
;
}
std
::
vector
<
pass
>
cpu_target
::
get_passes
(
context
&
)
const
{
return
{
cpu_pass
{}}
;
}
}
// namespace cpu
...
...
src/targets/cpu/include/migraph/cpu/cpu_target.hpp
View file @
583c76f2
...
...
@@ -9,7 +9,7 @@ namespace cpu {
struct
cpu_target
{
std
::
string
name
()
const
;
void
apply
(
program
&
p
)
const
;
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
)
const
;
context
get_context
()
const
{
return
{};
}
};
...
...
src/targets/miopen/include/migraph/miopen/miopen_target.hpp
View file @
583c76f2
...
...
@@ -9,7 +9,7 @@ namespace miopen {
struct
miopen_target
{
std
::
string
name
()
const
;
void
apply
(
program
&
p
)
const
;
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
)
const
;
context
get_context
()
const
;
};
...
...
src/targets/miopen/miopen_target.cpp
View file @
583c76f2
...
...
@@ -299,9 +299,16 @@ struct miopen_apply
}
};
std
::
string
miopen_target
::
name
()
const
{
return
"miopen"
;
}
struct
miopen_pass
{
std
::
string
name
()
const
{
return
"miopen::pass"
;
}
void
apply
(
program
&
p
)
const
{
miopen_apply
{
&
p
}.
apply
();
}
};
std
::
vector
<
pass
>
miopen_target
::
get_passes
(
context
&
)
const
{
return
{
miopen_pass
{}};
}
void
miopen_target
::
apply
(
program
&
p
)
const
{
miopen_apply
{
&
p
}.
apply
()
;
}
std
::
string
miopen_target
::
name
()
const
{
return
"miopen"
;
}
context
miopen_target
::
get_context
()
const
{
...
...
test/eval_test.cpp
View file @
583c76f2
...
...
@@ -2,6 +2,8 @@
#include <migraph/program.hpp>
#include <migraph/argument.hpp>
#include <migraph/shape.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/instruction.hpp>
#include <sstream>
#include "test.hpp"
...
...
@@ -68,7 +70,44 @@ struct minus_op
struct
id_target
{
std
::
string
name
()
const
{
return
"id"
;
}
void
apply
(
migraph
::
program
&
)
const
{}
std
::
vector
<
migraph
::
pass
>
get_passes
(
migraph
::
context
&
)
const
{
return
{};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
struct
reverse_pass
{
std
::
string
name
()
const
{
return
"reverse_pass"
;
}
void
apply
(
migraph
::
program
&
p
)
const
{
for
(
auto
ins
:
migraph
::
iterator_for
(
p
))
{
if
(
ins
->
op
.
name
()
==
"sum"
)
{
p
.
replace_instruction
(
ins
,
minus_op
{},
ins
->
arguments
);
}
else
if
(
ins
->
op
.
name
()
==
"minus"
)
{
p
.
replace_instruction
(
ins
,
sum_op
{},
ins
->
arguments
);
}
}
}
};
struct
reverse_target
{
std
::
string
name
()
const
{
return
"reverse"
;
}
std
::
vector
<
migraph
::
pass
>
get_passes
(
migraph
::
context
&
)
const
{
return
{
reverse_pass
{}};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
struct
double_reverse_target
{
std
::
string
name
()
const
{
return
"double_reverse"
;
}
std
::
vector
<
migraph
::
pass
>
get_passes
(
migraph
::
context
&
)
const
{
return
{
reverse_pass
{},
reverse_pass
{}};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
...
...
@@ -170,6 +209,32 @@ void target_test()
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
void
reverse_target_test
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
p
.
add_instruction
(
sum_op
{},
two
,
one
);
p
.
compile
(
reverse_target
{});
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
migraph
::
literal
{
1
});
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
void
double_reverse_target_test
()
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
p
.
add_instruction
(
sum_op
{},
two
,
one
);
p
.
compile
(
double_reverse_target
{});
auto
result
=
p
.
eval
({});
EXPECT
(
result
==
migraph
::
literal
{
3
});
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
int
main
()
{
literal_test1
();
...
...
@@ -179,4 +244,5 @@ int main()
replace_test
();
insert_replace_test
();
target_test
();
reverse_target_test
();
}
tools/include/pass.hpp
0 → 100644
View file @
583c76f2
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace
migraph
{
struct
program
;
<%
interface
(
'
pass
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
apply
'
,
returns
=
'
void
'
,
p
=
'
program
&
'
,
const
=
True
)
)
%>
}
// namespace migraph
#endif
tools/include/target.hpp
View file @
583c76f2
...
...
@@ -6,7 +6,9 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include <migraph/context.hpp>
#include <migraph/pass.hpp>
namespace
migraph
{
...
...
@@ -15,7 +17,7 @@ struct program;
<%
interface
(
'
target
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
apply
'
,
returns
=
'
void
'
,
p
=
'
program
&
'
,
const
=
True
),
virtual
(
'
get_passes
'
,
ctx
=
'
context
&
'
,
returns
=
'
std
::
vector
<
pass
>
'
,
const
=
True
),
virtual
(
'
get_context
'
,
returns
=
'
context
'
,
const
=
True
)
)
%>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment