Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d21778c6
Commit
d21778c6
authored
May 21, 2018
by
Paul
Browse files
Add shape param to compute
parent
0f318650
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
27 additions
and
28 deletions
+27
-28
src/include/rtg/builtin.hpp
src/include/rtg/builtin.hpp
+2
-2
src/include/rtg/operation.hpp
src/include/rtg/operation.hpp
+9
-9
src/include/rtg/operators.hpp
src/include/rtg/operators.hpp
+5
-5
src/onnx/read_onnx.cpp
src/onnx/read_onnx.cpp
+1
-1
src/program.cpp
src/program.cpp
+1
-1
src/targets/cpu/cpu_target.cpp
src/targets/cpu/cpu_target.cpp
+4
-5
test/eval_test.cpp
test/eval_test.cpp
+2
-2
test/operation.cpp
test/operation.cpp
+2
-2
tools/include/operation.hpp
tools/include/operation.hpp
+1
-1
No files found.
src/include/rtg/builtin.hpp
View file @
d21778c6
...
...
@@ -12,7 +12,7 @@ struct literal
{
std
::
string
name
()
const
{
return
"@literal"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
};
struct
param
...
...
@@ -20,7 +20,7 @@ struct param
std
::
string
parameter
;
std
::
string
name
()
const
{
return
"@param"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
param
&
op
)
{
os
<<
op
.
name
()
<<
":"
<<
op
.
parameter
;
...
...
src/include/rtg/operation.hpp
View file @
d21778c6
...
...
@@ -28,7 +28,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* argument compute(
shape output,
std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
*
...
...
@@ -95,10 +95,10 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
argument
compute
(
shape
output
,
std
::
vector
<
argument
>
input
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
std
::
move
(
input
));
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
std
::
move
(
output
),
std
::
move
(
input
));
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
...
...
@@ -114,10 +114,10 @@ struct operation
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
shape
output
,
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
};
template
<
typename
PrivateDetailTypeErasedT
>
...
...
@@ -156,10 +156,10 @@ struct operation
return
private_detail_te_value
.
compute_shape
(
std
::
move
(
input
));
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
override
argument
compute
(
shape
output
,
std
::
vector
<
argument
>
input
)
const
override
{
return
private_detail_te_value
.
compute
(
std
::
move
(
input
));
return
private_detail_te_value
.
compute
(
std
::
move
(
output
),
std
::
move
(
input
));
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
...
...
src/include/rtg/operators.hpp
View file @
d21778c6
...
...
@@ -10,7 +10,7 @@ namespace rtg {
struct
not_computable
{
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
};
struct
convolution
...
...
@@ -52,7 +52,7 @@ struct convolution
}};
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
{
...
...
@@ -98,7 +98,7 @@ struct pooling
}};
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
{
...
...
@@ -122,7 +122,7 @@ struct activation
return
inputs
.
front
();
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
{
os
<<
op
.
name
()
<<
":"
<<
op
.
mode
;
...
...
@@ -153,7 +153,7 @@ struct reshape
return
{
inputs
.
front
().
type
(),
rdims
};
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
...
...
src/onnx/read_onnx.cpp
View file @
d21778c6
...
...
@@ -22,7 +22,7 @@ struct unknown
else
return
input
.
front
();
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
...
...
src/program.cpp
View file @
d21778c6
...
...
@@ -109,7 +109,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
ins
.
arguments
.
end
(),
values
.
begin
(),
[
&
](
instruction_ref
i
)
{
return
results
.
at
(
std
::
addressof
(
*
i
));
});
result
=
ins
.
op
.
compute
(
values
);
result
=
ins
.
op
.
compute
(
ins
.
result
,
values
);
}
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
}
...
...
src/targets/cpu/cpu_target.cpp
View file @
d21778c6
...
...
@@ -13,10 +13,9 @@ struct cpu_convolution
std
::
string
name
()
const
{
return
"cpu::convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
shape
output_shape
=
compute_shape
({
args
[
0
].
get_shape
(),
args
[
1
].
get_shape
()});
argument
result
{
compute_shape
({
args
[
0
].
get_shape
(),
args
[
1
].
get_shape
()})};
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input
,
auto
weights
)
{
auto
in_n
=
input
.
get_shape
().
lens
()[
0
];
auto
in_c
=
input
.
get_shape
().
lens
()[
1
];
...
...
@@ -53,9 +52,9 @@ struct relu
std
::
string
name
()
const
{
return
"cpu::relu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
front
();
}
argument
compute
(
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
args
[
0
].
ge
t_shape
()
};
argument
result
{
outpu
t_shape
};
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[](
auto
x
)
{
...
...
test/eval_test.cpp
View file @
d21778c6
...
...
@@ -8,7 +8,7 @@
struct
sum_op
{
std
::
string
name
()
const
{
return
"sum"
;
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
args
)
const
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
args
)
const
{
rtg
::
argument
result
;
if
(
args
.
size
()
!=
2
)
...
...
@@ -37,7 +37,7 @@ struct sum_op
struct
minus_op
{
std
::
string
name
()
const
{
return
"minus"
;
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
args
)
const
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
args
)
const
{
rtg
::
argument
result
;
if
(
args
.
size
()
!=
2
)
...
...
test/operation.cpp
View file @
d21778c6
...
...
@@ -9,7 +9,7 @@ struct simple_operation
int
data
=
1
;
std
::
string
name
()
const
{
return
"simple"
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
simple_operation
&
op
)
{
os
<<
"["
<<
op
.
name
()
<<
"]"
;
...
...
@@ -21,7 +21,7 @@ struct simple_operation_no_print
{
std
::
string
name
()
const
{
return
"simple"
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
};
void
operation_copy_test
()
...
...
tools/include/operation.hpp
View file @
d21778c6
...
...
@@ -25,7 +25,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
rtg
::
operation_stream
::
operator
<<
'
)
)
%>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment