Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d21778c6
Commit
d21778c6
authored
May 21, 2018
by
Paul
Browse files
Add shape param to compute
parent
0f318650
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
27 additions
and
28 deletions
+27
-28
src/include/rtg/builtin.hpp
src/include/rtg/builtin.hpp
+2
-2
src/include/rtg/operation.hpp
src/include/rtg/operation.hpp
+9
-9
src/include/rtg/operators.hpp
src/include/rtg/operators.hpp
+5
-5
src/onnx/read_onnx.cpp
src/onnx/read_onnx.cpp
+1
-1
src/program.cpp
src/program.cpp
+1
-1
src/targets/cpu/cpu_target.cpp
src/targets/cpu/cpu_target.cpp
+4
-5
test/eval_test.cpp
test/eval_test.cpp
+2
-2
test/operation.cpp
test/operation.cpp
+2
-2
tools/include/operation.hpp
tools/include/operation.hpp
+1
-1
No files found.
src/include/rtg/builtin.hpp
View file @
d21778c6
...
@@ -12,7 +12,7 @@ struct literal
...
@@ -12,7 +12,7 @@ struct literal
{
{
std
::
string
name
()
const
{
return
"@literal"
;
}
std
::
string
name
()
const
{
return
"@literal"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
};
};
struct
param
struct
param
...
@@ -20,7 +20,7 @@ struct param
...
@@ -20,7 +20,7 @@ struct param
std
::
string
parameter
;
std
::
string
parameter
;
std
::
string
name
()
const
{
return
"@param"
;
}
std
::
string
name
()
const
{
return
"@param"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
param
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
param
&
op
)
{
{
os
<<
op
.
name
()
<<
":"
<<
op
.
parameter
;
os
<<
op
.
name
()
<<
":"
<<
op
.
parameter
;
...
...
src/include/rtg/operation.hpp
View file @
d21778c6
...
@@ -28,7 +28,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -28,7 +28,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* {
* {
* std::string name() const;
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* argument compute(
shape output,
std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
* };
*
*
...
@@ -95,10 +95,10 @@ struct operation
...
@@ -95,10 +95,10 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
return
(
*
this
).
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
}
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
argument
compute
(
shape
output
,
std
::
vector
<
argument
>
input
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
std
::
move
(
input
));
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
std
::
move
(
output
),
std
::
move
(
input
));
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
...
@@ -114,10 +114,10 @@ struct operation
...
@@ -114,10 +114,10 @@ struct operation
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
argument
compute
(
shape
output
,
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
};
};
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
...
@@ -156,10 +156,10 @@ struct operation
...
@@ -156,10 +156,10 @@ struct operation
return
private_detail_te_value
.
compute_shape
(
std
::
move
(
input
));
return
private_detail_te_value
.
compute_shape
(
std
::
move
(
input
));
}
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
override
argument
compute
(
shape
output
,
std
::
vector
<
argument
>
input
)
const
override
{
{
return
private_detail_te_value
.
compute
(
std
::
move
(
input
));
return
private_detail_te_value
.
compute
(
std
::
move
(
output
),
std
::
move
(
input
));
}
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
...
...
src/include/rtg/operators.hpp
View file @
d21778c6
...
@@ -10,7 +10,7 @@ namespace rtg {
...
@@ -10,7 +10,7 @@ namespace rtg {
struct
not_computable
struct
not_computable
{
{
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
};
};
struct
convolution
struct
convolution
...
@@ -52,7 +52,7 @@ struct convolution
...
@@ -52,7 +52,7 @@ struct convolution
}};
}};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
{
{
...
@@ -98,7 +98,7 @@ struct pooling
...
@@ -98,7 +98,7 @@ struct pooling
}};
}};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
{
{
...
@@ -122,7 +122,7 @@ struct activation
...
@@ -122,7 +122,7 @@ struct activation
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
{
{
os
<<
op
.
name
()
<<
":"
<<
op
.
mode
;
os
<<
op
.
name
()
<<
":"
<<
op
.
mode
;
...
@@ -153,7 +153,7 @@ struct reshape
...
@@ -153,7 +153,7 @@ struct reshape
return
{
inputs
.
front
().
type
(),
rdims
};
return
{
inputs
.
front
().
type
(),
rdims
};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
{
...
...
src/onnx/read_onnx.cpp
View file @
d21778c6
...
@@ -22,7 +22,7 @@ struct unknown
...
@@ -22,7 +22,7 @@ struct unknown
else
else
return
input
.
front
();
return
input
.
front
();
}
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
{
os
<<
x
.
name
();
os
<<
x
.
name
();
...
...
src/program.cpp
View file @
d21778c6
...
@@ -109,7 +109,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
...
@@ -109,7 +109,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
ins
.
arguments
.
end
(),
ins
.
arguments
.
end
(),
values
.
begin
(),
values
.
begin
(),
[
&
](
instruction_ref
i
)
{
return
results
.
at
(
std
::
addressof
(
*
i
));
});
[
&
](
instruction_ref
i
)
{
return
results
.
at
(
std
::
addressof
(
*
i
));
});
result
=
ins
.
op
.
compute
(
values
);
result
=
ins
.
op
.
compute
(
ins
.
result
,
values
);
}
}
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
}
}
...
...
src/targets/cpu/cpu_target.cpp
View file @
d21778c6
...
@@ -13,10 +13,9 @@ struct cpu_convolution
...
@@ -13,10 +13,9 @@ struct cpu_convolution
std
::
string
name
()
const
{
return
"cpu::convolution"
;
}
std
::
string
name
()
const
{
return
"cpu::convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
shape
output_shape
=
compute_shape
({
args
[
0
].
get_shape
(),
args
[
1
].
get_shape
()});
argument
result
{
output_shape
};
argument
result
{
compute_shape
({
args
[
0
].
get_shape
(),
args
[
1
].
get_shape
()})};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input
,
auto
weights
)
{
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input
,
auto
weights
)
{
auto
in_n
=
input
.
get_shape
().
lens
()[
0
];
auto
in_n
=
input
.
get_shape
().
lens
()[
0
];
auto
in_c
=
input
.
get_shape
().
lens
()[
1
];
auto
in_c
=
input
.
get_shape
().
lens
()[
1
];
...
@@ -53,9 +52,9 @@ struct relu
...
@@ -53,9 +52,9 @@ struct relu
std
::
string
name
()
const
{
return
"cpu::relu"
;
}
std
::
string
name
()
const
{
return
"cpu::relu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
front
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
front
();
}
argument
compute
(
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
args
[
0
].
ge
t_shape
()
};
argument
result
{
outpu
t_shape
};
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[](
auto
x
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[](
auto
x
)
{
...
...
test/eval_test.cpp
View file @
d21778c6
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
struct
sum_op
struct
sum_op
{
{
std
::
string
name
()
const
{
return
"sum"
;
}
std
::
string
name
()
const
{
return
"sum"
;
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
args
)
const
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
args
)
const
{
{
rtg
::
argument
result
;
rtg
::
argument
result
;
if
(
args
.
size
()
!=
2
)
if
(
args
.
size
()
!=
2
)
...
@@ -37,7 +37,7 @@ struct sum_op
...
@@ -37,7 +37,7 @@ struct sum_op
struct
minus_op
struct
minus_op
{
{
std
::
string
name
()
const
{
return
"minus"
;
}
std
::
string
name
()
const
{
return
"minus"
;
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
args
)
const
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
args
)
const
{
{
rtg
::
argument
result
;
rtg
::
argument
result
;
if
(
args
.
size
()
!=
2
)
if
(
args
.
size
()
!=
2
)
...
...
test/operation.cpp
View file @
d21778c6
...
@@ -9,7 +9,7 @@ struct simple_operation
...
@@ -9,7 +9,7 @@ struct simple_operation
int
data
=
1
;
int
data
=
1
;
std
::
string
name
()
const
{
return
"simple"
;
}
std
::
string
name
()
const
{
return
"simple"
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
simple_operation
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
simple_operation
&
op
)
{
{
os
<<
"["
<<
op
.
name
()
<<
"]"
;
os
<<
"["
<<
op
.
name
()
<<
"]"
;
...
@@ -21,7 +21,7 @@ struct simple_operation_no_print
...
@@ -21,7 +21,7 @@ struct simple_operation_no_print
{
{
std
::
string
name
()
const
{
return
"simple"
;
}
std
::
string
name
()
const
{
return
"simple"
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
};
};
void
operation_copy_test
()
void
operation_copy_test
()
...
...
tools/include/operation.hpp
View file @
d21778c6
...
@@ -25,7 +25,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -25,7 +25,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
interface
(
'
operation
'
,
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
rtg
::
operation_stream
::
operator
<<
'
)
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
rtg
::
operation_stream
::
operator
<<
'
)
)
)
%>
%>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment