Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a40d1b1d
Commit
a40d1b1d
authored
Nov 20, 2019
by
Paul
Browse files
Auto generate fallback code in type-erasure
parent
0045d0b7
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
216 additions
and
69 deletions
+216
-69
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+128
-30
src/include/migraphx/target.hpp
src/include/migraphx/target.hpp
+47
-3
tools/include/operation.hpp
tools/include/operation.hpp
+17
-30
tools/te.py
tools/te.py
+24
-6
No files found.
src/include/migraphx/operation.hpp
View file @
a40d1b1d
...
...
@@ -62,7 +62,9 @@ bool has_finalize(const operation& x);
#else
namespace
operation_stream
{
namespace
detail
{
namespace
operation_operators
{
template
<
class
T
>
auto
operator
<<
(
std
::
ostream
&
os
,
const
T
&
x
)
->
decltype
(
os
<<
x
.
name
())
...
...
@@ -80,10 +82,6 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
return
os
;
}
}
// namespace operation_stream
namespace
operation_equal
{
template
<
class
T
,
class
U
>
auto
operator
==
(
const
T
&
x
,
const
U
&
y
)
->
decltype
(
x
.
name
()
==
y
.
name
())
{
...
...
@@ -95,7 +93,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
return
reflect_tie
(
x
)
==
reflect_tie
(
yy
);
}
}
// namespace operation_
equal
}
// namespace operation_
operators
template
<
class
T
>
auto
compute_op
(
rank
<
2
>
,
...
...
@@ -177,24 +175,11 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
}
template
<
class
T
>
std
::
ptrdiff_t
output_alias_op
(
rank
<
0
>
,
const
T
&
,
const
std
::
vector
<
shape
>&
)
std
::
ptrdiff_t
output_alias_op
(
const
T
&
,
const
std
::
vector
<
shape
>&
)
{
return
-
1
;
}
template
<
class
T
>
auto
output_alias_op
(
rank
<
1
>
,
const
T
&
x
,
const
std
::
vector
<
shape
>&
shapes
)
->
decltype
(
x
.
output_alias
(
shapes
))
{
return
x
.
output_alias
(
shapes
);
}
template
<
class
T
>
std
::
ptrdiff_t
output_alias_op
(
const
T
&
x
,
const
std
::
vector
<
shape
>&
shapes
)
{
return
output_alias_op
(
rank
<
1
>
{},
x
,
shapes
);
}
template
<
class
T
>
auto
finalize_op
(
rank
<
1
>
,
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
...
...
@@ -233,6 +218,8 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return
{};
}
}
// namespace detail
/*
* Type-erased interface for:
*
...
...
@@ -396,6 +383,110 @@ struct operation
virtual
bool
operator
==
(
const
operation
&
y
)
const
=
0
;
};
template
<
class
T
>
static
auto
private_detail_te_default_is_context_free
(
char
,
T
&&
private_detail_te_self
)
->
decltype
(
private_detail_te_self
.
is_context_free
())
{
return
private_detail_te_self
.
is_context_free
();
}
template
<
class
T
>
static
bool
private_detail_te_default_is_context_free
(
float
,
T
&&
private_detail_te_self
)
{
return
detail
::
is_context_free_op
(
private_detail_te_self
);
}
template
<
class
T
>
static
auto
private_detail_te_default_has_finalize
(
char
,
T
&&
private_detail_te_self
)
->
decltype
(
private_detail_te_self
.
has_finalize
())
{
return
private_detail_te_self
.
has_finalize
();
}
template
<
class
T
>
static
bool
private_detail_te_default_has_finalize
(
float
,
T
&&
private_detail_te_self
)
{
return
detail
::
has_finalize_op
(
private_detail_te_self
);
}
template
<
class
T
>
static
auto
private_detail_te_default_output_alias
(
char
,
T
&&
private_detail_te_self
,
const
std
::
vector
<
shape
>&
input
)
->
decltype
(
private_detail_te_self
.
output_alias
(
input
))
{
return
private_detail_te_self
.
output_alias
(
input
);
}
template
<
class
T
>
static
std
::
ptrdiff_t
private_detail_te_default_output_alias
(
float
,
T
&&
private_detail_te_self
,
const
std
::
vector
<
shape
>&
input
)
{
return
detail
::
output_alias_op
(
private_detail_te_self
,
input
);
}
template
<
class
T
>
static
auto
private_detail_te_default_finalize
(
char
,
T
&&
private_detail_te_self
,
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
->
decltype
(
private_detail_te_self
.
finalize
(
ctx
,
output
,
input
))
{
private_detail_te_self
.
finalize
(
ctx
,
output
,
input
);
}
template
<
class
T
>
static
void
private_detail_te_default_finalize
(
float
,
T
&&
private_detail_te_self
,
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
{
detail
::
finalize_op
(
private_detail_te_self
,
ctx
,
output
,
input
);
}
template
<
class
T
>
static
auto
private_detail_te_default_compute
(
char
,
T
&&
private_detail_te_self
,
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
private_detail_te_self
.
compute
(
ctx
,
output
,
input
))
{
return
private_detail_te_self
.
compute
(
ctx
,
output
,
input
);
}
template
<
class
T
>
static
argument
private_detail_te_default_compute
(
float
,
T
&&
private_detail_te_self
,
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
{
return
detail
::
compute_op
(
private_detail_te_self
,
ctx
,
output
,
input
);
}
template
<
class
T
>
static
auto
private_detail_te_default_compute
(
char
,
T
&&
private_detail_te_self
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
private_detail_te_self
.
compute
(
output
,
input
))
{
return
private_detail_te_self
.
compute
(
output
,
input
);
}
template
<
class
T
>
static
argument
private_detail_te_default_compute
(
float
,
T
&&
private_detail_te_self
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
{
return
detail
::
compute_op
(
private_detail_te_self
,
output
,
input
);
}
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
{
...
...
@@ -429,21 +520,26 @@ struct operation
bool
is_context_free
()
const
override
{
return
is_context_free
_op
(
private_detail_te_value
);
return
private_detail_te_default_
is_context_free
(
char
(
0
),
private_detail_te_value
);
}
bool
has_finalize
()
const
override
{
return
has_finalize_op
(
private_detail_te_value
);
}
bool
has_finalize
()
const
override
{
return
private_detail_te_default_has_finalize
(
char
(
0
),
private_detail_te_value
);
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
override
{
return
output_alias
_op
(
private_detail_te_value
,
input
);
return
private_detail_te_default_
output_alias
(
char
(
0
),
private_detail_te_value
,
input
);
}
void
finalize
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
shape
>&
input
)
override
{
finalize_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
private_detail_te_default_finalize
(
char
(
0
),
private_detail_te_value
,
ctx
,
output
,
input
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
override
...
...
@@ -457,24 +553,26 @@ struct operation
const
std
::
vector
<
argument
>&
input
)
const
override
{
return
compute_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
return
private_detail_te_default_compute
(
char
(
0
),
private_detail_te_value
,
ctx
,
output
,
input
);
}
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
override
{
return
compute_op
(
private_detail_te_value
,
output
,
input
);
return
private_detail_te_default_compute
(
char
(
0
),
private_detail_te_value
,
output
,
input
);
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
{
using
migraphx
::
operation_
stream
::
operator
<<
;
using
migraphx
::
detail
::
operation_
operators
::
operator
<<
;
return
os
<<
private_detail_te_value
;
}
bool
operator
==
(
const
operation
&
y
)
const
override
{
using
migraphx
::
operation_
equal
::
operator
==
;
using
migraphx
::
detail
::
operation_
operators
::
operator
==
;
return
private_detail_te_value
==
y
;
}
...
...
@@ -550,7 +648,7 @@ inline bool is_context_free(const operation& op) { return op.is_context_free();
template
<
class
T
>
bool
is_context_free
(
const
T
&
x
)
{
return
is_context_free_op
(
x
);
return
detail
::
is_context_free_op
(
x
);
}
inline
bool
has_finalize
(
const
operation
&
op
)
{
return
op
.
has_finalize
();
}
...
...
@@ -558,7 +656,7 @@ inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template
<
class
T
>
bool
has_finalize
(
const
T
&
x
)
{
return
has_finalize_op
(
x
);
return
detail
::
has_finalize_op
(
x
);
}
#endif
...
...
src/include/migraphx/target.hpp
View file @
a40d1b1d
...
...
@@ -248,6 +248,50 @@ struct target
virtual
argument
allocate
(
const
shape
&
s
)
const
=
0
;
};
template
<
class
T
>
static
auto
private_detail_te_default_copy_to
(
char
,
T
&&
private_detail_te_self
,
const
argument
&
input
)
->
decltype
(
private_detail_te_self
.
copy_to
(
input
))
{
return
private_detail_te_self
.
copy_to
(
input
);
}
template
<
class
T
>
static
argument
private_detail_te_default_copy_to
(
float
,
T
&&
private_detail_te_self
,
const
argument
&
input
)
{
return
copy_to_target
(
private_detail_te_self
,
input
);
}
template
<
class
T
>
static
auto
private_detail_te_default_copy_from
(
char
,
T
&&
private_detail_te_self
,
const
argument
&
input
)
->
decltype
(
private_detail_te_self
.
copy_from
(
input
))
{
return
private_detail_te_self
.
copy_from
(
input
);
}
template
<
class
T
>
static
argument
private_detail_te_default_copy_from
(
float
,
T
&&
private_detail_te_self
,
const
argument
&
input
)
{
return
copy_from_target
(
private_detail_te_self
,
input
);
}
template
<
class
T
>
static
auto
private_detail_te_default_allocate
(
char
,
T
&&
private_detail_te_self
,
const
shape
&
s
)
->
decltype
(
private_detail_te_self
.
allocate
(
s
))
{
return
private_detail_te_self
.
allocate
(
s
);
}
template
<
class
T
>
static
argument
private_detail_te_default_allocate
(
float
,
T
&&
private_detail_te_self
,
const
shape
&
s
)
{
return
target_allocate
(
private_detail_te_self
,
s
);
}
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
{
...
...
@@ -289,19 +333,19 @@ struct target
argument
copy_to
(
const
argument
&
input
)
const
override
{
return
copy_to_target
(
private_detail_te_value
,
input
);
return
private_detail_te_default_copy_to
(
char
(
0
),
private_detail_te_value
,
input
);
}
argument
copy_from
(
const
argument
&
input
)
const
override
{
return
copy_from_target
(
private_detail_te_value
,
input
);
return
private_detail_te_default_copy_from
(
char
(
0
),
private_detail_te_value
,
input
);
}
argument
allocate
(
const
shape
&
s
)
const
override
{
return
targe
t_allocate
(
private_detail_te_value
,
s
);
return
private_detail_te_defaul
t_allocate
(
char
(
0
),
private_detail_te_value
,
s
);
}
PrivateDetailTypeErasedT
private_detail_te_value
;
...
...
tools/include/operation.hpp
View file @
a40d1b1d
...
...
@@ -62,7 +62,9 @@ bool has_finalize(const operation& x);
#else
namespace
operation_stream
{
namespace
detail
{
namespace
operation_operators
{
template
<
class
T
>
auto
operator
<<
(
std
::
ostream
&
os
,
const
T
&
x
)
->
decltype
(
os
<<
x
.
name
())
...
...
@@ -80,10 +82,6 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
return
os
;
}
}
// namespace operation_stream
namespace
operation_equal
{
template
<
class
T
,
class
U
>
auto
operator
==
(
const
T
&
x
,
const
U
&
y
)
->
decltype
(
x
.
name
()
==
y
.
name
())
{
...
...
@@ -95,7 +93,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
return
reflect_tie
(
x
)
==
reflect_tie
(
yy
);
}
}
// namespace operation_
equal
}
// namespace operation_
operators
template
<
class
T
>
auto
compute_op
(
rank
<
2
>
,
...
...
@@ -177,24 +175,11 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
}
template
<
class
T
>
std
::
ptrdiff_t
output_alias_op
(
rank
<
0
>
,
const
T
&
,
const
std
::
vector
<
shape
>&
)
std
::
ptrdiff_t
output_alias_op
(
const
T
&
,
const
std
::
vector
<
shape
>&
)
{
return
-
1
;
}
template
<
class
T
>
auto
output_alias_op
(
rank
<
1
>
,
const
T
&
x
,
const
std
::
vector
<
shape
>&
shapes
)
->
decltype
(
x
.
output_alias
(
shapes
))
{
return
x
.
output_alias
(
shapes
);
}
template
<
class
T
>
std
::
ptrdiff_t
output_alias_op
(
const
T
&
x
,
const
std
::
vector
<
shape
>&
shapes
)
{
return
output_alias_op
(
rank
<
1
>
{},
x
,
shapes
);
}
template
<
class
T
>
auto
finalize_op
(
rank
<
1
>
,
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
...
...
@@ -233,22 +218,24 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return
{};
}
}
// namespace detail
<%
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
is_context_free
'
,
returns
=
'
bool
'
,
const
=
True
,
default
=
'
is_context_free_op
'
),
virtual
(
'
has_finalize
'
,
returns
=
'
bool
'
,
const
=
True
,
default
=
'
has_finalize_op
'
),
virtual
(
'
is_context_free
'
,
returns
=
'
bool
'
,
const
=
True
,
default
=
'
detail
::
is_context_free_op
'
),
virtual
(
'
has_finalize
'
,
returns
=
'
bool
'
,
const
=
True
,
default
=
'
detail
::
has_finalize_op
'
),
virtual
(
'
output_alias
'
,
returns
=
'
std
::
ptrdiff_t
'
,
input
=
'
const
std
::
vector
<
shape
>&
'
,
const
=
True
,
default
=
'
output_alias_op
'
),
default
=
'
detail
::
output_alias_op
'
),
virtual
(
'
finalize
'
,
ctx
=
'
context
&
'
,
output
=
'
const
shape
&
'
,
input
=
'
const
std
::
vector
<
shape
>&
'
,
default
=
'
finalize_op
'
),
default
=
'
detail
::
finalize_op
'
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
const
std
::
vector
<
shape
>&
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
...
...
@@ -256,23 +243,23 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
output
=
'
const
shape
&
'
,
input
=
'
const
std
::
vector
<
argument
>&
'
,
const
=
True
,
default
=
'
compute_op
'
),
default
=
'
detail
::
compute_op
'
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
output
=
'
const
shape
&
'
,
input
=
'
const
std
::
vector
<
argument
>&
'
,
const
=
True
,
default
=
'
compute_op
'
),
default
=
'
detail
::
compute_op
'
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraphx
::
operation_
stream
::
operator
<<
'
),
using
=
'
migraphx
::
detail
::
operation_
operators
::
operator
<<
'
),
friend
(
'
operator
==
'
,
returns
=
'
bool
'
,
x
=
'
const
operation
&
'
,
y
=
'
const
operation
&
'
,
using
=
'
migraphx
::
operation_
equal
::
operator
==
'
))
%>
using
=
'
migraphx
::
detail
::
operation_
operators
::
operator
==
'
))
%>
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
...
...
@@ -284,7 +271,7 @@ inline bool is_context_free(const operation& op) { return op.is_context_free();
template
<
class
T
>
bool
is_context_free
(
const
T
&
x
)
{
return
is_context_free_op
(
x
);
return
detail
::
is_context_free_op
(
x
);
}
inline
bool
has_finalize
(
const
operation
&
op
)
{
return
op
.
has_finalize
();
}
...
...
@@ -292,7 +279,7 @@ inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template
<
class
T
>
bool
has_finalize
(
const
T
&
x
)
{
return
has_finalize_op
(
x
);
return
detail
::
has_finalize_op
(
x
);
}
#endif
...
...
tools/te.py
View file @
a40d1b1d
...
...
@@ -88,6 +88,8 @@ private:
${pure_virtual_members}
};
${default_members}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type :
private_detail_te_handle_base_type
...
...
@@ -205,6 +207,21 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override
comment_member
=
string
.
Template
(
'''* ${friend} ${return_type} ${name}(${params}) ${const};'''
)
default_member
=
string
.
Template
(
'''
template<class T>
static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params})
-> decltype(private_detail_te_self.${name}(${args}))
{
${return_} private_detail_te_self.${name}(${args});
}
template<class T>
static ${return_type} private_detail_te_default_${internal_name}(float, T&& private_detail_te_self ${comma} ${member_params})
{
${return_} ${default}(private_detail_te_self ${comma} ${args});
}
'''
)
def
trim_type_name
(
name
):
n
=
name
.
strip
()
...
...
@@ -237,12 +254,8 @@ def generate_call(m, friend, indirect):
if
friend
:
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
if
indirect
:
if
m
[
'args'
]:
return
string
.
Template
(
'${default}(private_detail_te_value, ${args})'
).
substitute
(
m
)
else
:
return
string
.
Template
(
'${default}(
private_detail_te_value)'
).
substitute
(
m
)
'private_detail_te_default_${internal_name}(char(0),
private_detail_te_value
${comma} ${args}
)'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
...
...
@@ -314,6 +327,7 @@ def convert_member(d, struct_name):
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'member_params'
]
=
','
.
join
(
member_params
)
member
[
'comma'
]
=
','
if
len
(
args
)
>
0
else
''
member
[
'call'
]
=
generate_call
(
member
,
friend
,
indirect
)
return
member
return
None
...
...
@@ -324,15 +338,19 @@ def generate_form(name, members):
pure_virtual_members
=
[]
virtual_members
=
[]
comment_members
=
[]
default_members
=
[]
for
member
in
members
:
m
=
convert_member
(
member
,
name
)
nonvirtual_members
.
append
(
nonvirtual_member
.
substitute
(
m
))
pure_virtual_members
.
append
(
pure_virtual_member
.
substitute
(
m
))
virtual_members
.
append
(
virtual_member
.
substitute
(
m
))
comment_members
.
append
(
comment_member
.
substitute
(
m
))
if
'default'
in
m
:
default_members
.
append
(
default_member
.
substitute
(
m
))
return
form
.
substitute
(
nonvirtual_members
=
''
.
join
(
nonvirtual_members
),
pure_virtual_members
=
''
.
join
(
pure_virtual_members
),
virtual_members
=
''
.
join
(
virtual_members
),
default_members
=
''
.
join
(
default_members
),
comment_members
=
'
\n
'
.
join
(
comment_members
),
struct_name
=
name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment