Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a9e5d73c
Commit
a9e5d73c
authored
Jul 06, 2018
by
Paul
Browse files
Update get_passes for the cpu
parent
cff16121
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
273 additions
and
12 deletions
+273
-12
src/include/migraph/pass.hpp
src/include/migraph/pass.hpp
+204
-0
src/include/migraph/target.hpp
src/include/migraph/target.hpp
+13
-7
src/program.cpp
src/program.cpp
+8
-1
src/targets/cpu/cpu_target.cpp
src/targets/cpu/cpu_target.cpp
+17
-1
src/targets/cpu/include/migraph/cpu/cpu_target.hpp
src/targets/cpu/include/migraph/cpu/cpu_target.hpp
+1
-1
test/eval_test.cpp
test/eval_test.cpp
+4
-1
tools/include/pass.hpp
tools/include/pass.hpp
+23
-0
tools/include/target.hpp
tools/include/target.hpp
+3
-1
No files found.
src/include/migraph/pass.hpp
0 → 100644
View file @
a9e5d73c
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace
migraph
{
struct
program
;
/*
* Type-erased interface for:
*
* struct pass
* {
* std::string name() const;
* void apply(program & p) const;
* };
*
*/
struct
pass
{
// Constructors
pass
()
=
default
;
template
<
typename
PrivateDetailTypeErasedT
>
pass
(
PrivateDetailTypeErasedT
value
)
:
private_detail_te_handle_mem_var
(
std
::
make_shared
<
private_detail_te_handle_type
<
typename
std
::
remove_reference
<
PrivateDetailTypeErasedT
>::
type
>>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
)))
{
}
// Assignment
template
<
typename
PrivateDetailTypeErasedT
>
pass
&
operator
=
(
PrivateDetailTypeErasedT
value
)
{
if
(
private_detail_te_handle_mem_var
.
unique
())
*
private_detail_te_handle_mem_var
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
else
if
(
!
private_detail_te_handle_mem_var
)
private_detail_te_handle_mem_var
=
std
::
make_shared
<
PrivateDetailTypeErasedT
>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
));
return
*
this
;
}
// Cast
template
<
typename
PrivateDetailTypeErasedT
>
PrivateDetailTypeErasedT
*
any_cast
()
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
.
private_detail_te_value
)
:
nullptr
;
}
template
<
typename
PrivateDetailTypeErasedT
>
const
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
*
any_cast
()
const
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
const
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
.
private_detail_te_value
)
:
nullptr
;
}
const
std
::
type_info
&
type_id
()
const
{
if
(
private_detail_te_handle_empty
())
return
typeid
(
std
::
nullptr_t
);
else
return
private_detail_te_get_handle
().
type
();
}
std
::
string
name
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
void
apply
(
program
&
p
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
}
private:
struct
private_detail_te_handle_base_type
{
virtual
~
private_detail_te_handle_base_type
()
{}
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
void
apply
(
program
&
p
)
const
=
0
;
};
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
{
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
>::
type
*
=
nullptr
)
:
private_detail_te_value
(
value
)
{
}
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
}
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
override
{
return
std
::
make_shared
<
private_detail_te_handle_type
>
(
private_detail_te_value
);
}
const
std
::
type_info
&
type
()
const
override
{
return
typeid
(
private_detail_te_value
);
}
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
void
apply
(
program
&
p
)
const
override
{
return
private_detail_te_value
.
apply
(
p
);
}
PrivateDetailTypeErasedT
private_detail_te_value
;
};
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
<
std
::
reference_wrapper
<
PrivateDetailTypeErasedT
>>
:
private_detail_te_handle_type
<
PrivateDetailTypeErasedT
&>
{
private_detail_te_handle_type
(
std
::
reference_wrapper
<
PrivateDetailTypeErasedT
>
ref
)
:
private_detail_te_handle_type
<
PrivateDetailTypeErasedT
&>
(
ref
.
get
())
{
}
};
bool
private_detail_te_handle_empty
()
const
{
return
private_detail_te_handle_mem_var
==
nullptr
;
}
const
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
const
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
return
*
private_detail_te_handle_mem_var
;
}
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
private_detail_te_handle_mem_var
;
};
template
<
typename
ValueType
>
inline
const
ValueType
*
any_cast
(
const
pass
*
x
)
{
return
x
->
any_cast
<
ValueType
>
();
}
template
<
typename
ValueType
>
inline
ValueType
*
any_cast
(
pass
*
x
)
{
return
x
->
any_cast
<
ValueType
>
();
}
template
<
typename
ValueType
>
inline
ValueType
&
any_cast
(
pass
&
x
)
{
auto
*
y
=
x
.
any_cast
<
typename
std
::
remove_reference
<
ValueType
>::
type
>
();
if
(
y
==
nullptr
)
throw
std
::
bad_cast
();
return
*
y
;
}
template
<
typename
ValueType
>
inline
const
ValueType
&
any_cast
(
const
pass
&
x
)
{
const
auto
*
y
=
x
.
any_cast
<
typename
std
::
remove_reference
<
ValueType
>::
type
>
();
if
(
y
==
nullptr
)
throw
std
::
bad_cast
();
return
*
y
;
}
}
// namespace migraph
#endif
src/include/migraph/target.hpp
View file @
a9e5d73c
...
@@ -6,7 +6,9 @@
...
@@ -6,7 +6,9 @@
#include <memory>
#include <memory>
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>
#include <vector>
#include <migraph/context.hpp>
#include <migraph/context.hpp>
#include <migraph/pass.hpp>
namespace
migraph
{
namespace
migraph
{
...
@@ -18,7 +20,7 @@ struct program;
...
@@ -18,7 +20,7 @@ struct program;
* struct target
* struct target
* {
* {
* std::string name() const;
* std::string name() const;
*
void apply(program & p
) const;
*
std::vector<pass> get_passes(context& ctx
) const;
* context get_context() const;
* context get_context() const;
* };
* };
*
*
...
@@ -87,10 +89,10 @@ struct target
...
@@ -87,10 +89,10 @@ struct target
return
(
*
this
).
private_detail_te_get_handle
().
name
();
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
}
void
apply
(
program
&
p
)
const
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
return
(
*
this
).
private_detail_te_get_handle
().
get_passes
(
ctx
);
}
}
context
get_context
()
const
context
get_context
()
const
...
@@ -106,9 +108,9 @@ struct target
...
@@ -106,9 +108,9 @@ struct target
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
void
apply
(
program
&
p
)
const
=
0
;
virtual
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
)
const
=
0
;
virtual
context
get_context
()
const
=
0
;
virtual
context
get_context
()
const
=
0
;
};
};
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
...
@@ -141,7 +143,11 @@ struct target
...
@@ -141,7 +143,11 @@ struct target
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
void
apply
(
program
&
p
)
const
override
{
return
private_detail_te_value
.
apply
(
p
);
}
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
)
const
override
{
return
private_detail_te_value
.
get_passes
(
ctx
);
}
context
get_context
()
const
override
{
return
private_detail_te_value
.
get_context
();
}
context
get_context
()
const
override
{
return
private_detail_te_value
.
get_context
();
}
...
...
src/program.cpp
View file @
a9e5d73c
...
@@ -111,7 +111,14 @@ void program::compile(const target& t)
...
@@ -111,7 +111,14 @@ void program::compile(const target& t)
{
{
assert
(
this
->
validate
()
!=
impl
->
instructions
.
end
());
assert
(
this
->
validate
()
!=
impl
->
instructions
.
end
());
this
->
impl
->
ctx
=
t
.
get_context
();
this
->
impl
->
ctx
=
t
.
get_context
();
t
.
apply
(
*
this
);
for
(
auto
&&
p
:
t
.
get_passes
(
this
->
impl
->
ctx
))
{
p
.
apply
(
*
this
);
#ifndef NDEBUG
if
(
this
->
validate
()
==
impl
->
instructions
.
end
())
MIGRAPH_THROW
(
p
.
name
()
+
" pass produces invalid program"
);
#endif
}
if
(
this
->
validate
()
==
impl
->
instructions
.
end
())
if
(
this
->
validate
()
==
impl
->
instructions
.
end
())
MIGRAPH_THROW
(
"Invalid program from compilation"
);
MIGRAPH_THROW
(
"Invalid program from compilation"
);
}
}
...
...
src/targets/cpu/cpu_target.cpp
View file @
a9e5d73c
...
@@ -538,9 +538,25 @@ struct cpu_apply
...
@@ -538,9 +538,25 @@ struct cpu_apply
}
}
};
};
struct
cpu_pass
{
std
::
string
name
()
const
{
return
"cpu::pass"
;
}
void
apply
(
program
&
p
)
const
{
cpu_apply
{
&
p
}.
apply
();
}
};
std
::
string
cpu_target
::
name
()
const
{
return
"cpu"
;
}
std
::
string
cpu_target
::
name
()
const
{
return
"cpu"
;
}
void
cpu_target
::
apply
(
program
&
p
)
const
{
cpu_apply
{
&
p
}.
apply
();
}
std
::
vector
<
pass
>
cpu_target
::
get_passes
(
context
&
)
const
{
return
{
cpu_pass
{}
};
}
}
// namespace cpu
}
// namespace cpu
...
...
src/targets/cpu/include/migraph/cpu/cpu_target.hpp
View file @
a9e5d73c
...
@@ -9,7 +9,7 @@ namespace cpu {
...
@@ -9,7 +9,7 @@ namespace cpu {
struct
cpu_target
struct
cpu_target
{
{
std
::
string
name
()
const
;
std
::
string
name
()
const
;
void
apply
(
program
&
p
)
const
;
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
)
const
;
context
get_context
()
const
{
return
{};
}
context
get_context
()
const
{
return
{};
}
};
};
...
...
test/eval_test.cpp
View file @
a9e5d73c
...
@@ -68,7 +68,10 @@ struct minus_op
...
@@ -68,7 +68,10 @@ struct minus_op
struct
id_target
struct
id_target
{
{
std
::
string
name
()
const
{
return
"id"
;
}
std
::
string
name
()
const
{
return
"id"
;
}
void
apply
(
migraph
::
program
&
)
const
{}
std
::
vector
<
migraph
::
pass
>
get_passes
(
migraph
::
context
&
)
const
{
return
{};
}
migraph
::
context
get_context
()
const
{
return
{};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
};
...
...
tools/include/pass.hpp
0 → 100644
View file @
a9e5d73c
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace
migraph
{
struct
program
;
<%
interface
(
'
pass
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
apply
'
,
returns
=
'
void
'
,
p
=
'
program
&
'
,
const
=
True
)
)
%>
}
// namespace migraph
#endif
tools/include/target.hpp
View file @
a9e5d73c
...
@@ -6,7 +6,9 @@
...
@@ -6,7 +6,9 @@
#include <memory>
#include <memory>
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>
#include <vector>
#include <migraph/context.hpp>
#include <migraph/context.hpp>
#include <migraph/pass.hpp>
namespace
migraph
{
namespace
migraph
{
...
@@ -15,7 +17,7 @@ struct program;
...
@@ -15,7 +17,7 @@ struct program;
<%
<%
interface
(
'
target
'
,
interface
(
'
target
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
apply
'
,
returns
=
'
void
'
,
p
=
'
program
&
'
,
const
=
True
),
virtual
(
'
get_passes
'
,
ctx
=
'
context
&
'
,
returns
=
'
std
::
vector
<
pass
>
'
,
const
=
True
),
virtual
(
'
get_context
'
,
returns
=
'
context
'
,
const
=
True
)
virtual
(
'
get_context
'
,
returns
=
'
context
'
,
const
=
True
)
)
)
%>
%>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment