Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cdea8d86
You need to sign in or sign up before continuing.
Commit
cdea8d86
authored
May 16, 2018
by
Paul
Browse files
Add stream operator
parent
ab0ea297
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
104 additions
and
23 deletions
+104
-23
frontend/onnx/read_onnx.cpp
frontend/onnx/read_onnx.cpp
+5
-0
include/rtg/builtin.hpp
include/rtg/builtin.hpp
+10
-0
include/rtg/operation.hpp
include/rtg/operation.hpp
+22
-9
include/rtg/operators.hpp
include/rtg/operators.hpp
+23
-0
include/rtg/target.hpp
include/rtg/target.hpp
+4
-4
test/eval_test.cpp
test/eval_test.cpp
+12
-0
test/operation.cpp
test/operation.cpp
+5
-0
tools/include/operation.hpp
tools/include/operation.hpp
+2
-1
tools/te.py
tools/te.py
+21
-9
No files found.
frontend/onnx/read_onnx.cpp
View file @
cdea8d86
...
@@ -23,6 +23,11 @@ struct unknown
...
@@ -23,6 +23,11 @@ struct unknown
return
input
.
front
();
return
input
.
front
();
}
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
return
os
;
}
};
};
template
<
class
C
,
class
T
>
template
<
class
C
,
class
T
>
...
...
include/rtg/builtin.hpp
View file @
cdea8d86
...
@@ -13,6 +13,11 @@ struct literal
...
@@ -13,6 +13,11 @@ struct literal
std
::
string
name
()
const
{
return
"@literal"
;
}
std
::
string
name
()
const
{
return
"@literal"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
literal
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
param
struct
param
...
@@ -21,6 +26,11 @@ struct param
...
@@ -21,6 +26,11 @@ struct param
std
::
string
name
()
const
{
return
"@param:"
+
parameter
;
}
std
::
string
name
()
const
{
return
"@param:"
+
parameter
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
param
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
}
// namespace builtin
}
// namespace builtin
...
...
include/rtg/operation.hpp
View file @
cdea8d86
...
@@ -19,6 +19,7 @@ namespace rtg {
...
@@ -19,6 +19,7 @@ namespace rtg {
* std::string name() const;
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* argument compute(std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
* };
*
*
*/
*/
...
@@ -74,20 +75,26 @@ struct operation
...
@@ -74,20 +75,26 @@ struct operation
std
::
string
name
()
const
std
::
string
name
()
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
name
();
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
return
(
*
this
).
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
}
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
argument
compute
(
std
::
vector
<
argument
>
input
)
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
compute
(
std
::
move
(
input
));
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
std
::
move
(
input
));
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
assert
(
op
.
private_detail_te_handle_mem_var
);
return
op
.
private_detail_te_get_handle
().
operator_shift_left
(
os
);
}
}
private:
private:
...
@@ -97,9 +104,10 @@ struct operation
...
@@ -97,9 +104,10 @@ struct operation
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
};
};
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
...
@@ -142,6 +150,11 @@ struct operation
...
@@ -142,6 +150,11 @@ struct operation
return
private_detail_te_value
.
compute
(
std
::
move
(
input
));
return
private_detail_te_value
.
compute
(
std
::
move
(
input
));
}
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
{
return
os
<<
private_detail_te_value
;
}
PrivateDetailTypeErasedT
private_detail_te_value
;
PrivateDetailTypeErasedT
private_detail_te_value
;
};
};
...
...
include/rtg/operators.hpp
View file @
cdea8d86
...
@@ -56,6 +56,12 @@ struct convolution
...
@@ -56,6 +56,12 @@ struct convolution
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
pooling
struct
pooling
...
@@ -96,6 +102,12 @@ struct pooling
...
@@ -96,6 +102,12 @@ struct pooling
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
activation
struct
activation
...
@@ -110,6 +122,11 @@ struct activation
...
@@ -110,6 +122,11 @@ struct activation
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
reshape
struct
reshape
...
@@ -136,6 +153,12 @@ struct reshape
...
@@ -136,6 +153,12 @@ struct reshape
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
}
// namespace rtg
}
// namespace rtg
...
...
include/rtg/target.hpp
View file @
cdea8d86
...
@@ -73,14 +73,14 @@ struct target
...
@@ -73,14 +73,14 @@ struct target
std
::
string
name
()
const
std
::
string
name
()
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
name
();
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
}
void
apply
(
program
&
p
)
const
void
apply
(
program
&
p
)
const
{
{
assert
(
private_detail_te_handle_mem_var
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
apply
(
p
);
return
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
}
}
private:
private:
...
...
test/eval_test.cpp
View file @
cdea8d86
...
@@ -31,6 +31,12 @@ struct sum_op
...
@@ -31,6 +31,12 @@ struct sum_op
RTG_THROW
(
"Wrong inputs"
);
RTG_THROW
(
"Wrong inputs"
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
sum_op
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
minus_op
struct
minus_op
...
@@ -60,6 +66,12 @@ struct minus_op
...
@@ -60,6 +66,12 @@ struct minus_op
RTG_THROW
(
"Wrong inputs"
);
RTG_THROW
(
"Wrong inputs"
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
minus_op
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
struct
id_target
struct
id_target
...
...
test/operation.cpp
View file @
cdea8d86
...
@@ -10,6 +10,11 @@ struct simple_operation
...
@@ -10,6 +10,11 @@ struct simple_operation
std
::
string
name
()
const
{
return
"simple"
;
}
std
::
string
name
()
const
{
return
"simple"
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
simple_operation
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
};
void
operation_copy_test
()
void
operation_copy_test
()
...
...
tools/include/operation.hpp
View file @
cdea8d86
...
@@ -15,7 +15,8 @@ namespace rtg {
...
@@ -15,7 +15,8 @@ namespace rtg {
interface
(
'
operation
'
,
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
)
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
)
)
)
%>
%>
...
...
tools/te.py
View file @
cdea8d86
...
@@ -163,15 +163,15 @@ inline const ValueType & any_cast(const ${struct_name} & x)
...
@@ -163,15 +163,15 @@ inline const ValueType & any_cast(const ${struct_name} & x)
nonvirtual_member
=
string
.
Template
(
'''
nonvirtual_member
=
string
.
Template
(
'''
${friend} ${return_type} ${name}(${params}) ${const}
${friend} ${return_type} ${name}(${params}) ${const}
{
{
assert(private_detail_te_handle_mem_var);
assert(
${this}.
private_detail_te_handle_mem_var);
return private_detail_te_get_handle().${internal_name}(${member_args});
return
${this}.
private_detail_te_get_handle().${internal_name}(${member_args});
}
}
'''
)
'''
)
pure_virtual_member
=
string
.
Template
(
"virtual ${return_type} ${internal_name}(${member_params}) ${const} = 0;
\n
"
)
pure_virtual_member
=
string
.
Template
(
"virtual ${return_type} ${internal_name}(${member_params}) ${
member_
const} = 0;
\n
"
)
virtual_member
=
string
.
Template
(
'''
virtual_member
=
string
.
Template
(
'''
${return_type} ${internal_name}(${member_params}) ${const} override
${return_type} ${internal_name}(${member_params}) ${
member_
const} override
{
{
return ${call};
return ${call};
}
}
...
@@ -201,17 +201,24 @@ def generate_call(m, friend):
...
@@ -201,17 +201,24 @@ def generate_call(m, friend):
if
m
[
'name'
].
startswith
(
'operator'
):
if
m
[
'name'
].
startswith
(
'operator'
):
op
=
m
[
'name'
][
8
:]
op
=
m
[
'name'
][
8
:]
args
=
m
[
'args'
]
args
=
m
[
'args'
]
if
len
(
m
[
args
])
==
2
:
if
','
in
args
:
return
string
.
Tem
pla
t
e
(
'
${arg1} ${op} ${arg2}'
).
substitute
(
op
=
op
,
arg1
=
args
[
0
],
arg2
=
args
[
1
]
)
return
args
.
re
pla
c
e
(
'
,'
,
op
)
else
:
else
:
return
string
.
Template
(
'${op}${arg
1
}'
).
substitute
(
op
=
op
,
arg
1
=
args
[
0
]
)
return
string
.
Template
(
'${op}${arg
a
}'
).
substitute
(
op
=
op
,
arg
s
=
args
)
if
friend
:
if
friend
:
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
def
convert_member
(
d
,
struct_name
):
def
convert_member
(
d
,
struct_name
):
for
name
in
d
:
for
name
in
d
:
member
=
{
'name'
:
name
,
'internal_name'
:
internal_name
(
name
),
'const'
:
''
,
'friend'
:
''
}
member
=
{
'name'
:
name
,
'internal_name'
:
internal_name
(
name
),
'const'
:
''
,
'member_const'
:
''
,
'friend'
:
''
,
'this'
:
'(*this)'
}
args
=
[]
args
=
[]
params
=
[]
params
=
[]
member_args
=
[]
member_args
=
[]
...
@@ -227,12 +234,17 @@ def convert_member(d, struct_name):
...
@@ -227,12 +234,17 @@ def convert_member(d, struct_name):
member
[
'return_type'
]
=
t
member
[
'return_type'
]
=
t
elif
x
==
'const'
:
elif
x
==
'const'
:
member
[
'const'
]
=
'const'
member
[
'const'
]
=
'const'
member
[
'member_const'
]
=
'const'
elif
x
==
'friend'
:
elif
x
==
'friend'
:
member
[
'friend'
]
=
'friend'
member
[
'friend'
]
=
'friend'
else
:
else
:
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
arg_name
=
x
arg_name
=
x
if
not
use_member
:
arg_name
=
'private_detail_te_value'
if
not
use_member
:
arg_name
=
'private_detail_te_value'
member
[
'this'
]
=
x
if
'const'
in
t
:
member
[
'member_const'
]
=
'const'
if
t
.
endswith
((
'&'
,
'*'
)):
if
t
.
endswith
((
'&'
,
'*'
)):
if
use_member
:
member_args
.
append
(
x
)
if
use_member
:
member_args
.
append
(
x
)
args
.
append
(
arg_name
)
args
.
append
(
arg_name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment