Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cdea8d86
Commit
cdea8d86
authored
May 16, 2018
by
Paul
Browse files
Add stream operator
parent
ab0ea297
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
104 additions
and
23 deletions
+104
-23
frontend/onnx/read_onnx.cpp
frontend/onnx/read_onnx.cpp
+5
-0
include/rtg/builtin.hpp
include/rtg/builtin.hpp
+10
-0
include/rtg/operation.hpp
include/rtg/operation.hpp
+22
-9
include/rtg/operators.hpp
include/rtg/operators.hpp
+23
-0
include/rtg/target.hpp
include/rtg/target.hpp
+4
-4
test/eval_test.cpp
test/eval_test.cpp
+12
-0
test/operation.cpp
test/operation.cpp
+5
-0
tools/include/operation.hpp
tools/include/operation.hpp
+2
-1
tools/te.py
tools/te.py
+21
-9
No files found.
frontend/onnx/read_onnx.cpp
View file @
cdea8d86
...
@@ -23,6 +23,11 @@ struct unknown
...
@@ -23,6 +23,11 @@ struct unknown
return
input
.
front
();
return
input
.
front
();
}
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
return
os
;
}
};
};
template
<
class
C
,
class
T
>
template
<
class
C
,
class
T
>
...
...
include/rtg/builtin.hpp
View file @
cdea8d86
...
@@ -13,6 +13,11 @@ struct literal
...
@@ -13,6 +13,11 @@ struct literal
std
::
string
name
()
const
{
return
"@literal"
;
}
std
::
string
name
()
const
{
return
"@literal"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
literal
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
param
struct
param
...
@@ -21,6 +26,11 @@ struct param
...
@@ -21,6 +26,11 @@ struct param
std
::
string
name
()
const
{
return
"@param:"
+
parameter
;
}
std
::
string
name
()
const
{
return
"@param:"
+
parameter
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
param
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
}
// namespace builtin
}
// namespace builtin
...
...
include/rtg/operation.hpp
View file @
cdea8d86
...
@@ -19,6 +19,7 @@ namespace rtg {
...
@@ -19,6 +19,7 @@ namespace rtg {
* std::string name() const;
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* argument compute(std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
* };
*
*
*/
*/
...
@@ -74,20 +75,26 @@ struct operation
...
@@ -74,20 +75,26 @@ struct operation
std
::
string
name
()
const
std
::
string
name
()
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
name
();
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
return
(
*
this
).
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
}
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
argument
compute
(
std
::
vector
<
argument
>
input
)
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
compute
(
std
::
move
(
input
));
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
std
::
move
(
input
));
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
assert
(
op
.
private_detail_te_handle_mem_var
);
return
op
.
private_detail_te_get_handle
().
operator_shift_left
(
os
);
}
}
private:
private:
...
@@ -97,9 +104,10 @@ struct operation
...
@@ -97,9 +104,10 @@ struct operation
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
};
};
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
...
@@ -142,6 +150,11 @@ struct operation
...
@@ -142,6 +150,11 @@ struct operation
return
private_detail_te_value
.
compute
(
std
::
move
(
input
));
return
private_detail_te_value
.
compute
(
std
::
move
(
input
));
}
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
{
return
os
<<
private_detail_te_value
;
}
PrivateDetailTypeErasedT
private_detail_te_value
;
PrivateDetailTypeErasedT
private_detail_te_value
;
};
};
...
...
include/rtg/operators.hpp
View file @
cdea8d86
...
@@ -56,6 +56,12 @@ struct convolution
...
@@ -56,6 +56,12 @@ struct convolution
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
pooling
struct
pooling
...
@@ -96,6 +102,12 @@ struct pooling
...
@@ -96,6 +102,12 @@ struct pooling
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
activation
struct
activation
...
@@ -110,6 +122,11 @@ struct activation
...
@@ -110,6 +122,11 @@ struct activation
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
reshape
struct
reshape
...
@@ -136,6 +153,12 @@ struct reshape
...
@@ -136,6 +153,12 @@ struct reshape
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
}
// namespace rtg
}
// namespace rtg
...
...
include/rtg/target.hpp
View file @
cdea8d86
...
@@ -73,14 +73,14 @@ struct target
...
@@ -73,14 +73,14 @@ struct target
std
::
string
name
()
const
std
::
string
name
()
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
name
();
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
}
void
apply
(
program
&
p
)
const
void
apply
(
program
&
p
)
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
apply
(
p
);
return
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
}
}
private:
private:
...
...
test/eval_test.cpp
View file @
cdea8d86
...
@@ -31,6 +31,12 @@ struct sum_op
...
@@ -31,6 +31,12 @@ struct sum_op
RTG_THROW
(
"Wrong inputs"
);
RTG_THROW
(
"Wrong inputs"
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
sum_op
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
minus_op
struct
minus_op
...
@@ -60,6 +66,12 @@ struct minus_op
...
@@ -60,6 +66,12 @@ struct minus_op
RTG_THROW
(
"Wrong inputs"
);
RTG_THROW
(
"Wrong inputs"
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
minus_op
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
id_target
struct
id_target
...
...
test/operation.cpp
View file @
cdea8d86
...
@@ -10,6 +10,11 @@ struct simple_operation
...
@@ -10,6 +10,11 @@ struct simple_operation
std
::
string
name
()
const
{
return
"simple"
;
}
std
::
string
name
()
const
{
return
"simple"
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
simple_operation
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
void
operation_copy_test
()
void
operation_copy_test
()
...
...
tools/include/operation.hpp
View file @
cdea8d86
...
@@ -15,7 +15,8 @@ namespace rtg {
...
@@ -15,7 +15,8 @@ namespace rtg {
interface
(
'
operation
'
,
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
)
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
)
)
)
%>
%>
...
...
tools/te.py
View file @
cdea8d86
...
@@ -163,15 +163,15 @@ inline const ValueType & any_cast(const ${struct_name} & x)
...
@@ -163,15 +163,15 @@ inline const ValueType & any_cast(const ${struct_name} & x)
nonvirtual_member
=
string
.
Template
(
'''
nonvirtual_member
=
string
.
Template
(
'''
${friend} ${return_type} ${name}(${params}) ${const}
${friend} ${return_type} ${name}(${params}) ${const}
{
{
assert(private_detail_te_handle_mem_var);
assert(
${this}.
private_detail_te_handle_mem_var);
return private_detail_te_get_handle().${internal_name}(${member_args});
return
${this}.
private_detail_te_get_handle().${internal_name}(${member_args});
}
}
'''
)
'''
)
pure_virtual_member
=
string
.
Template
(
"virtual ${return_type} ${internal_name}(${member_params}) ${const} = 0;
\n
"
)
pure_virtual_member
=
string
.
Template
(
"virtual ${return_type} ${internal_name}(${member_params}) ${
member_
const} = 0;
\n
"
)
virtual_member
=
string
.
Template
(
'''
virtual_member
=
string
.
Template
(
'''
${return_type} ${internal_name}(${member_params}) ${const} override
${return_type} ${internal_name}(${member_params}) ${
member_
const} override
{
{
return ${call};
return ${call};
}
}
...
@@ -201,17 +201,24 @@ def generate_call(m, friend):
...
@@ -201,17 +201,24 @@ def generate_call(m, friend):
if
m
[
'name'
].
startswith
(
'operator'
):
if
m
[
'name'
].
startswith
(
'operator'
):
op
=
m
[
'name'
][
8
:]
op
=
m
[
'name'
][
8
:]
args
=
m
[
'args'
]
args
=
m
[
'args'
]
if
len
(
m
[
args
])
==
2
:
if
','
in
args
:
return
string
.
Tem
pla
t
e
(
'
${arg1} ${op} ${arg2}'
).
substitute
(
op
=
op
,
arg1
=
args
[
0
],
arg2
=
args
[
1
]
)
return
args
.
re
pla
c
e
(
'
,'
,
op
)
else
:
else
:
return
string
.
Template
(
'${op}${arg
1
}'
).
substitute
(
op
=
op
,
arg
1
=
args
[
0
]
)
return
string
.
Template
(
'${op}${arg
a
}'
).
substitute
(
op
=
op
,
arg
s
=
args
)
if
friend
:
if
friend
:
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
def
convert_member
(
d
,
struct_name
):
def
convert_member
(
d
,
struct_name
):
for
name
in
d
:
for
name
in
d
:
member
=
{
'name'
:
name
,
'internal_name'
:
internal_name
(
name
),
'const'
:
''
,
'friend'
:
''
}
member
=
{
'name'
:
name
,
'internal_name'
:
internal_name
(
name
),
'const'
:
''
,
'member_const'
:
''
,
'friend'
:
''
,
'this'
:
'(*this)'
}
args
=
[]
args
=
[]
params
=
[]
params
=
[]
member_args
=
[]
member_args
=
[]
...
@@ -227,12 +234,17 @@ def convert_member(d, struct_name):
...
@@ -227,12 +234,17 @@ def convert_member(d, struct_name):
member
[
'return_type'
]
=
t
member
[
'return_type'
]
=
t
elif
x
==
'const'
:
elif
x
==
'const'
:
member
[
'const'
]
=
'const'
member
[
'const'
]
=
'const'
member
[
'member_const'
]
=
'const'
elif
x
==
'friend'
:
elif
x
==
'friend'
:
member
[
'friend'
]
=
'friend'
member
[
'friend'
]
=
'friend'
else
:
else
:
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
arg_name
=
x
arg_name
=
x
if
not
use_member
:
arg_name
=
'private_detail_te_value'
if
not
use_member
:
arg_name
=
'private_detail_te_value'
member
[
'this'
]
=
x
if
'const'
in
t
:
member
[
'member_const'
]
=
'const'
if
t
.
endswith
((
'&'
,
'*'
)):
if
t
.
endswith
((
'&'
,
'*'
)):
if
use_member
:
member_args
.
append
(
x
)
if
use_member
:
member_args
.
append
(
x
)
args
.
append
(
arg_name
)
args
.
append
(
arg_name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment