Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cdea8d86
Commit
cdea8d86
authored
May 16, 2018
by
Paul
Browse files
Add stream operator
parent
ab0ea297
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
104 additions
and
23 deletions
+104
-23
frontend/onnx/read_onnx.cpp
frontend/onnx/read_onnx.cpp
+5
-0
include/rtg/builtin.hpp
include/rtg/builtin.hpp
+10
-0
include/rtg/operation.hpp
include/rtg/operation.hpp
+22
-9
include/rtg/operators.hpp
include/rtg/operators.hpp
+23
-0
include/rtg/target.hpp
include/rtg/target.hpp
+4
-4
test/eval_test.cpp
test/eval_test.cpp
+12
-0
test/operation.cpp
test/operation.cpp
+5
-0
tools/include/operation.hpp
tools/include/operation.hpp
+2
-1
tools/te.py
tools/te.py
+21
-9
No files found.
frontend/onnx/read_onnx.cpp
View file @
cdea8d86
...
...
@@ -23,6 +23,11 @@ struct unknown
return
input
.
front
();
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
return
os
;
}
};
template
<
class
C
,
class
T
>
...
...
include/rtg/builtin.hpp
View file @
cdea8d86
...
...
@@ -13,6 +13,11 @@ struct literal
std
::
string
name
()
const
{
return
"@literal"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
literal
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
struct
param
...
...
@@ -21,6 +26,11 @@ struct param
std
::
string
name
()
const
{
return
"@param:"
+
parameter
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
param
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
}
// namespace builtin
...
...
include/rtg/operation.hpp
View file @
cdea8d86
...
...
@@ -19,6 +19,7 @@ namespace rtg {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
*
*/
...
...
@@ -74,20 +75,26 @@ struct operation
std
::
string
name
()
const
{
assert
(
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
name
();
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
assert
(
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
{
assert
(
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
compute
(
std
::
move
(
input
));
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
std
::
move
(
input
));
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
assert
(
op
.
private_detail_te_handle_mem_var
);
return
op
.
private_detail_te_get_handle
().
operator_shift_left
(
os
);
}
private:
...
...
@@ -100,6 +107,7 @@ struct operation
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
};
template
<
typename
PrivateDetailTypeErasedT
>
...
...
@@ -142,6 +150,11 @@ struct operation
return
private_detail_te_value
.
compute
(
std
::
move
(
input
));
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
{
return
os
<<
private_detail_te_value
;
}
PrivateDetailTypeErasedT
private_detail_te_value
;
};
...
...
include/rtg/operators.hpp
View file @
cdea8d86
...
...
@@ -56,6 +56,12 @@ struct convolution
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
struct
pooling
...
...
@@ -96,6 +102,12 @@ struct pooling
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
struct
activation
...
...
@@ -110,6 +122,11 @@ struct activation
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
struct
reshape
...
...
@@ -136,6 +153,12 @@ struct reshape
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
}
// namespace rtg
...
...
include/rtg/target.hpp
View file @
cdea8d86
...
...
@@ -73,14 +73,14 @@ struct target
std
::
string
name
()
const
{
assert
(
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
name
();
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
void
apply
(
program
&
p
)
const
{
assert
(
private_detail_te_handle_mem_var
);
return
private_detail_te_get_handle
().
apply
(
p
);
assert
(
(
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
apply
(
p
);
}
private:
...
...
test/eval_test.cpp
View file @
cdea8d86
...
...
@@ -31,6 +31,12 @@ struct sum_op
RTG_THROW
(
"Wrong inputs"
);
return
inputs
.
front
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
sum_op
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
struct
minus_op
...
...
@@ -60,6 +66,12 @@ struct minus_op
RTG_THROW
(
"Wrong inputs"
);
return
inputs
.
front
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
minus_op
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
struct
id_target
...
...
test/operation.cpp
View file @
cdea8d86
...
...
@@ -10,6 +10,11 @@ struct simple_operation
std
::
string
name
()
const
{
return
"simple"
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
simple_operation
&
op
)
{
os
<<
op
.
name
();
return
os
;
}
};
void
operation_copy_test
()
...
...
tools/include/operation.hpp
View file @
cdea8d86
...
...
@@ -15,7 +15,8 @@ namespace rtg {
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
)
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
)
)
%>
...
...
tools/te.py
View file @
cdea8d86
...
...
@@ -163,15 +163,15 @@ inline const ValueType & any_cast(const ${struct_name} & x)
nonvirtual_member
=
string
.
Template
(
'''
${friend} ${return_type} ${name}(${params}) ${const}
{
assert(private_detail_te_handle_mem_var);
return private_detail_te_get_handle().${internal_name}(${member_args});
assert(
${this}.
private_detail_te_handle_mem_var);
return
${this}.
private_detail_te_get_handle().${internal_name}(${member_args});
}
'''
)
pure_virtual_member
=
string
.
Template
(
"virtual ${return_type} ${internal_name}(${member_params}) ${const} = 0;
\n
"
)
pure_virtual_member
=
string
.
Template
(
"virtual ${return_type} ${internal_name}(${member_params}) ${
member_
const} = 0;
\n
"
)
virtual_member
=
string
.
Template
(
'''
${return_type} ${internal_name}(${member_params}) ${const} override
${return_type} ${internal_name}(${member_params}) ${
member_
const} override
{
return ${call};
}
...
...
@@ -201,17 +201,24 @@ def generate_call(m, friend):
if
m
[
'name'
].
startswith
(
'operator'
):
op
=
m
[
'name'
][
8
:]
args
=
m
[
'args'
]
if
len
(
m
[
args
])
==
2
:
return
string
.
Tem
pla
t
e
(
'
${arg1} ${op} ${arg2}'
).
substitute
(
op
=
op
,
arg1
=
args
[
0
],
arg2
=
args
[
1
]
)
if
','
in
args
:
return
args
.
re
pla
c
e
(
'
,'
,
op
)
else
:
return
string
.
Template
(
'${op}${arg
1
}'
).
substitute
(
op
=
op
,
arg
1
=
args
[
0
]
)
return
string
.
Template
(
'${op}${arg
a
}'
).
substitute
(
op
=
op
,
arg
s
=
args
)
if
friend
:
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
def
convert_member
(
d
,
struct_name
):
for
name
in
d
:
member
=
{
'name'
:
name
,
'internal_name'
:
internal_name
(
name
),
'const'
:
''
,
'friend'
:
''
}
member
=
{
'name'
:
name
,
'internal_name'
:
internal_name
(
name
),
'const'
:
''
,
'member_const'
:
''
,
'friend'
:
''
,
'this'
:
'(*this)'
}
args
=
[]
params
=
[]
member_args
=
[]
...
...
@@ -227,12 +234,17 @@ def convert_member(d, struct_name):
member
[
'return_type'
]
=
t
elif
x
==
'const'
:
member
[
'const'
]
=
'const'
member
[
'member_const'
]
=
'const'
elif
x
==
'friend'
:
member
[
'friend'
]
=
'friend'
else
:
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
arg_name
=
x
if
not
use_member
:
arg_name
=
'private_detail_te_value'
if
not
use_member
:
arg_name
=
'private_detail_te_value'
member
[
'this'
]
=
x
if
'const'
in
t
:
member
[
'member_const'
]
=
'const'
if
t
.
endswith
((
'&'
,
'*'
)):
if
use_member
:
member_args
.
append
(
x
)
args
.
append
(
arg_name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment