Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
330fe429
Commit
330fe429
authored
May 17, 2018
by
Paul
Browse files
Merge branch 'friend-op'
parents
f54dcb28
8dc320a6
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
184 additions
and
32 deletions
+184
-32
Jenkinsfile
Jenkinsfile
+1
-1
frontend/onnx/read_onnx.cpp
frontend/onnx/read_onnx.cpp
+5
-0
include/rtg/builtin.hpp
include/rtg/builtin.hpp
+10
-0
include/rtg/operation.hpp
include/rtg/operation.hpp
+38
-12
include/rtg/operators.hpp
include/rtg/operators.hpp
+23
-0
include/rtg/target.hpp
include/rtg/target.hpp
+6
-6
test/operation.cpp
test/operation.cpp
+5
-0
tools/include/operation.hpp
tools/include/operation.hpp
+12
-1
tools/te.py
tools/te.py
+84
-12
No files found.
Jenkinsfile
View file @
330fe429
...
@@ -7,7 +7,7 @@ def rocmtestnode(variant, name, body) {
...
@@ -7,7 +7,7 @@ def rocmtestnode(variant, name, body) {
mkdir build
mkdir build
cd build
cd build
CXX=${compiler} CXXFLAGS='-Werror' cmake -DCMAKE_CXX_FLAGS_DEBUG='-g -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined' ${flags} ..
CXX=${compiler} CXXFLAGS='-Werror' cmake -DCMAKE_CXX_FLAGS_DEBUG='-g -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined' ${flags} ..
CTEST_PARALLEL_LEVEL=32 make -j32 check
CTEST_PARALLEL_LEVEL=32 make -j32
all doc
check
"""
"""
echo
cmd
echo
cmd
sh
cmd
sh
cmd
...
...
frontend/onnx/read_onnx.cpp
View file @
330fe429
...
@@ -23,6 +23,11 @@ struct unknown
...
@@ -23,6 +23,11 @@ struct unknown
return
input
.
front
();
return
input
.
front
();
}
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
return
os
;
}
};
};
template
<
class
C
,
class
T
>
template
<
class
C
,
class
T
>
...
...
include/rtg/builtin.hpp
View file @
330fe429
...
@@ -13,6 +13,11 @@ struct literal
...
@@ -13,6 +13,11 @@ struct literal
std
::
string
name
()
const
{
return
"@literal"
;
}
std
::
string
name
()
const
{
return
"@literal"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
literal
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
param
struct
param
...
@@ -21,6 +26,11 @@ struct param
...
@@ -21,6 +26,11 @@ struct param
std
::
string
name
()
const
{
return
"@param:"
+
parameter
;
}
std
::
string
name
()
const
{
return
"@param:"
+
parameter
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
param
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
}
// namespace builtin
}
// namespace builtin
...
...
include/rtg/operation.hpp
View file @
330fe429
...
@@ -11,6 +11,16 @@
...
@@ -11,6 +11,16 @@
namespace
rtg
{
namespace
rtg
{
namespace
operation_stream
{
template
<
class
T
>
auto
operator
<<
(
std
::
ostream
&
os
,
const
T
&
x
)
->
decltype
(
os
<<
x
.
name
())
{
return
os
<<
x
.
name
();
}
}
// namespace operation_stream
/*
/*
* Type-erased interface for:
* Type-erased interface for:
*
*
...
@@ -19,6 +29,7 @@ namespace rtg {
...
@@ -19,6 +29,7 @@ namespace rtg {
* std::string name() const;
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* argument compute(std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
* };
*
*
*/
*/
...
@@ -74,20 +85,26 @@ struct operation
...
@@ -74,20 +85,26 @@ struct operation
std
::
string
name
()
const
std
::
string
name
()
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
name
();
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
return
(
*
this
).
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
}
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
argument
compute
(
std
::
vector
<
argument
>
input
)
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
compute
(
std
::
move
(
input
));
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
std
::
move
(
input
));
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
assert
(
op
.
private_detail_te_handle_mem_var
);
return
op
.
private_detail_te_get_handle
().
operator_shift_left
(
os
);
}
}
private:
private:
...
@@ -100,6 +117,7 @@ struct operation
...
@@ -100,6 +117,7 @@ struct operation
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
};
};
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
...
@@ -134,14 +152,22 @@ struct operation
...
@@ -134,14 +152,22 @@ struct operation
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
override
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
override
{
{
return
private_detail_te_value
.
compute_shape
(
std
::
move
(
input
));
return
private_detail_te_value
.
compute_shape
(
std
::
move
(
input
));
}
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
override
argument
compute
(
std
::
vector
<
argument
>
input
)
const
override
{
{
return
private_detail_te_value
.
compute
(
std
::
move
(
input
));
return
private_detail_te_value
.
compute
(
std
::
move
(
input
));
}
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
{
using
rtg
::
operation_stream
::
operator
<<
;
return
os
<<
private_detail_te_value
;
}
PrivateDetailTypeErasedT
private_detail_te_value
;
PrivateDetailTypeErasedT
private_detail_te_value
;
};
};
...
...
include/rtg/operators.hpp
View file @
330fe429
...
@@ -56,6 +56,12 @@ struct convolution
...
@@ -56,6 +56,12 @@ struct convolution
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
pooling
struct
pooling
...
@@ -96,6 +102,12 @@ struct pooling
...
@@ -96,6 +102,12 @@ struct pooling
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
activation
struct
activation
...
@@ -110,6 +122,11 @@ struct activation
...
@@ -110,6 +122,11 @@ struct activation
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
reshape
struct
reshape
...
@@ -136,6 +153,12 @@ struct reshape
...
@@ -136,6 +153,12 @@ struct reshape
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
}
// namespace rtg
}
// namespace rtg
...
...
include/rtg/target.hpp
View file @
330fe429
...
@@ -73,14 +73,14 @@ struct target
...
@@ -73,14 +73,14 @@ struct target
std
::
string
name
()
const
std
::
string
name
()
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
name
();
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
}
void
apply
(
program
&
p
)
const
void
apply
(
program
&
p
)
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
apply
(
p
);
return
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
}
}
private:
private:
...
...
test/operation.cpp
View file @
330fe429
...
@@ -10,6 +10,11 @@ struct simple_operation
...
@@ -10,6 +10,11 @@ struct simple_operation
std
::
string
name
()
const
{
return
"simple"
;
}
std
::
string
name
()
const
{
return
"simple"
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
simple_operation
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
void
operation_copy_test
()
void
operation_copy_test
()
...
...
tools/include/operation.hpp
View file @
330fe429
...
@@ -11,11 +11,22 @@
...
@@ -11,11 +11,22 @@
namespace
rtg
{
namespace
rtg
{
namespace
operation_stream
{
template
<
class
T
>
auto
operator
<<
(
std
::
ostream
&
os
,
const
T
&
x
)
->
decltype
(
os
<<
x
.
name
())
{
return
os
<<
x
.
name
();
}
}
// namespace operation_stream
<%
<%
interface
(
'
operation
'
,
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
)
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
rtg
::
operation_stream
::
operator
<<
'
)
)
)
%>
%>
...
...
tools/te.py
View file @
330fe429
...
@@ -161,43 +161,109 @@ inline const ValueType & any_cast(const ${struct_name} & x)
...
@@ -161,43 +161,109 @@ inline const ValueType & any_cast(const ${struct_name} & x)
'''
)
'''
)
nonvirtual_member
=
string
.
Template
(
'''
nonvirtual_member
=
string
.
Template
(
'''
${return_type} ${name}(${params}) ${const}
${friend}
${return_type} ${name}(${params}) ${const}
{
{
assert(private_detail_te_handle_mem_var);
assert(
${this}.
private_detail_te_handle_mem_var);
return private_detail_te_get_handle().${name}(${args});
return
${this}.
private_detail_te_get_handle().${
internal_
name}(${
member_
args});
}
}
'''
)
'''
)
pure_virtual_member
=
string
.
Template
(
"virtual ${return_type} ${name}(${params}) ${const} = 0;
\n
"
)
pure_virtual_member
=
string
.
Template
(
"virtual ${return_type} ${
internal_
name}(${
member_
params}) ${
member_
const} = 0;
\n
"
)
virtual_member
=
string
.
Template
(
'''
virtual_member
=
string
.
Template
(
'''
${return_type} ${name}(${params}) ${const} override
${return_type} ${
internal_
name}(${
member_
params}) ${
member_
const} override
{
{
return private_detail_te_value.${name}(${args});
${using}
return ${call};
}
}
'''
)
'''
)
comment_member
=
string
.
Template
(
'''* ${return_type} ${name}(${params}) ${const};'''
)
comment_member
=
string
.
Template
(
'''*
${friend}
${return_type} ${name}(${params}) ${const};'''
)
def
convert_member
(
d
):
def
trim_type_name
(
name
):
n
=
name
.
strip
()
if
n
.
startswith
(
'const'
):
return
trim_type_name
(
n
[
5
:])
if
n
.
endswith
((
'&'
,
'*'
)):
return
trim_type_name
(
n
[
0
:
-
1
])
return
n
def
internal_name
(
name
):
internal_names
=
{
'operator<<'
:
'operator_shift_left'
,
'operator>>'
:
'operator_shift_right'
,
}
if
name
in
internal_names
:
return
internal_names
[
name
]
else
:
return
name
def
generate_call
(
m
,
friend
):
if
m
[
'name'
].
startswith
(
'operator'
):
op
=
m
[
'name'
][
8
:]
args
=
m
[
'args'
]
if
','
in
args
:
return
args
.
replace
(
','
,
op
)
else
:
return
string
.
Template
(
'${op}${arga}'
).
substitute
(
op
=
op
,
args
=
args
)
if
friend
:
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
def
convert_member
(
d
,
struct_name
):
for
name
in
d
:
for
name
in
d
:
member
=
{
'name'
:
name
,
'const'
:
''
}
member
=
{
'name'
:
name
,
'internal_name'
:
internal_name
(
name
),
'const'
:
''
,
'member_const'
:
''
,
'friend'
:
''
,
'this'
:
'(*this)'
,
'using'
:
''
}
args
=
[]
args
=
[]
params
=
[]
params
=
[]
member_args
=
[]
member_params
=
[]
skip
=
False
friend
=
False
if
'friend'
in
d
[
name
]:
friend
=
True
skip
=
True
for
x
in
d
[
name
]:
for
x
in
d
[
name
]:
t
=
d
[
name
][
x
]
t
=
d
[
name
][
x
]
if
x
==
'return'
:
if
x
==
'return'
:
member
[
'return_type'
]
=
t
member
[
'return_type'
]
=
t
elif
x
==
'const'
:
elif
x
==
'const'
:
member
[
'const'
]
=
'const'
member
[
'const'
]
=
'const'
member
[
'member_const'
]
=
'const'
elif
x
==
'friend'
:
member
[
'friend'
]
=
'friend'
elif
x
==
'using'
:
member
[
'using'
]
=
'using {};'
.
format
(
d
[
name
][
'using'
])
else
:
else
:
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
arg_name
=
x
if
not
use_member
:
arg_name
=
'private_detail_te_value'
member
[
'this'
]
=
x
if
'const'
in
t
:
member
[
'member_const'
]
=
'const'
if
t
.
endswith
((
'&'
,
'*'
)):
if
t
.
endswith
((
'&'
,
'*'
)):
args
.
append
(
x
)
if
use_member
:
member_args
.
append
(
x
)
args
.
append
(
arg_name
)
else
:
else
:
args
.
append
(
'std::move({})'
.
format
(
x
))
if
use_member
:
member_args
.
append
(
'std::move({})'
.
format
(
x
))
args
.
append
(
'std::move({})'
.
format
(
arg_name
))
params
.
append
(
t
+
' '
+
x
)
params
.
append
(
t
+
' '
+
x
)
if
use_member
:
member_params
.
append
(
t
+
' '
+
x
)
else
:
skip
=
False
member
[
'args'
]
=
','
.
join
(
args
)
member
[
'args'
]
=
','
.
join
(
args
)
member
[
'member_args'
]
=
','
.
join
(
member_args
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'member_params'
]
=
','
.
join
(
member_params
)
member
[
'call'
]
=
generate_call
(
member
,
friend
)
return
member
return
member
return
None
return
None
...
@@ -208,7 +274,7 @@ def generate_form(name, members):
...
@@ -208,7 +274,7 @@ def generate_form(name, members):
virtual_members
=
[]
virtual_members
=
[]
comment_members
=
[]
comment_members
=
[]
for
member
in
members
:
for
member
in
members
:
m
=
convert_member
(
member
)
m
=
convert_member
(
member
,
name
)
nonvirtual_members
.
append
(
nonvirtual_member
.
substitute
(
m
))
nonvirtual_members
.
append
(
nonvirtual_member
.
substitute
(
m
))
pure_virtual_members
.
append
(
pure_virtual_member
.
substitute
(
m
))
pure_virtual_members
.
append
(
pure_virtual_member
.
substitute
(
m
))
virtual_members
.
append
(
virtual_member
.
substitute
(
m
))
virtual_members
.
append
(
virtual_member
.
substitute
(
m
))
...
@@ -226,6 +292,12 @@ def virtual(name, returns=None, **kwargs):
...
@@ -226,6 +292,12 @@ def virtual(name, returns=None, **kwargs):
args
[
'return'
]
=
returns
args
[
'return'
]
=
returns
return
{
name
:
args
}
return
{
name
:
args
}
def
friend
(
name
,
returns
=
None
,
**
kwargs
):
args
=
kwargs
args
[
'return'
]
=
returns
args
[
'friend'
]
=
'friend'
return
{
name
:
args
}
def
interface
(
name
,
*
members
):
def
interface
(
name
,
*
members
):
return
generate_form
(
name
,
members
)
return
generate_form
(
name
,
members
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment