Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d21778c6
Commit
d21778c6
authored
May 21, 2018
by
Paul
Browse files
Add shape param to compute
parent
0f318650
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
27 additions
and
28 deletions
+27
-28
src/include/rtg/builtin.hpp
src/include/rtg/builtin.hpp
+2
-2
src/include/rtg/operation.hpp
src/include/rtg/operation.hpp
+9
-9
src/include/rtg/operators.hpp
src/include/rtg/operators.hpp
+5
-5
src/onnx/read_onnx.cpp
src/onnx/read_onnx.cpp
+1
-1
src/program.cpp
src/program.cpp
+1
-1
src/targets/cpu/cpu_target.cpp
src/targets/cpu/cpu_target.cpp
+4
-5
test/eval_test.cpp
test/eval_test.cpp
+2
-2
test/operation.cpp
test/operation.cpp
+2
-2
tools/include/operation.hpp
tools/include/operation.hpp
+1
-1
No files found.
src/include/rtg/builtin.hpp
View file @
d21778c6
...
@@ -12,7 +12,7 @@ struct literal
...
@@ -12,7 +12,7 @@ struct literal
{
{
std
::
string
name
()
const
{
return
"@literal"
;
}
std
::
string
name
()
const
{
return
"@literal"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
};
};
struct
param
struct
param
...
@@ -20,7 +20,7 @@ struct param
...
@@ -20,7 +20,7 @@ struct param
std
::
string
parameter
;
std
::
string
parameter
;
std
::
string
name
()
const
{
return
"@param"
;
}
std
::
string
name
()
const
{
return
"@param"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
param
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
param
&
op
)
{
{
os
<<
op
.
name
()
<<
":"
<<
op
.
parameter
;
os
<<
op
.
name
()
<<
":"
<<
op
.
parameter
;
...
...
src/include/rtg/operation.hpp
View file @
d21778c6
...
@@ -28,7 +28,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -28,7 +28,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* {
* {
* std::string name() const;
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* argument compute(
shape output,
std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
* };
*
*
...
@@ -95,10 +95,10 @@ struct operation
...
@@ -95,10 +95,10 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
return
(
*
this
).
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
));
}
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
argument
compute
(
shape
output
,
std
::
vector
<
argument
>
input
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
std
::
move
(
input
));
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
std
::
move
(
output
),
std
::
move
(
input
));
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
...
@@ -114,10 +114,10 @@ struct operation
...
@@ -114,10 +114,10 @@ struct operation
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
argument
compute
(
shape
output
,
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
};
};
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
...
@@ -156,10 +156,10 @@ struct operation
...
@@ -156,10 +156,10 @@ struct operation
return
private_detail_te_value
.
compute_shape
(
std
::
move
(
input
));
return
private_detail_te_value
.
compute_shape
(
std
::
move
(
input
));
}
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
override
argument
compute
(
shape
output
,
std
::
vector
<
argument
>
input
)
const
override
{
{
return
private_detail_te_value
.
compute
(
std
::
move
(
input
));
return
private_detail_te_value
.
compute
(
std
::
move
(
output
),
std
::
move
(
input
));
}
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
...
...
src/include/rtg/operators.hpp
View file @
d21778c6
...
@@ -10,7 +10,7 @@ namespace rtg {
...
@@ -10,7 +10,7 @@ namespace rtg {
struct
not_computable
struct
not_computable
{
{
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
};
};
struct
convolution
struct
convolution
...
@@ -52,7 +52,7 @@ struct convolution
...
@@ -52,7 +52,7 @@ struct convolution
}};
}};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
convolution
&
op
)
{
{
...
@@ -98,7 +98,7 @@ struct pooling
...
@@ -98,7 +98,7 @@ struct pooling
}};
}};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
pooling
&
op
)
{
{
...
@@ -122,7 +122,7 @@ struct activation
...
@@ -122,7 +122,7 @@ struct activation
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
activation
&
op
)
{
{
os
<<
op
.
name
()
<<
":"
<<
op
.
mode
;
os
<<
op
.
name
()
<<
":"
<<
op
.
mode
;
...
@@ -153,7 +153,7 @@ struct reshape
...
@@ -153,7 +153,7 @@ struct reshape
return
{
inputs
.
front
().
type
(),
rdims
};
return
{
inputs
.
front
().
type
(),
rdims
};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
{
...
...
src/onnx/read_onnx.cpp
View file @
d21778c6
...
@@ -22,7 +22,7 @@ struct unknown
...
@@ -22,7 +22,7 @@ struct unknown
else
else
return
input
.
front
();
return
input
.
front
();
}
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
{
os
<<
x
.
name
();
os
<<
x
.
name
();
...
...
src/program.cpp
View file @
d21778c6
...
@@ -109,7 +109,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
...
@@ -109,7 +109,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
ins
.
arguments
.
end
(),
ins
.
arguments
.
end
(),
values
.
begin
(),
values
.
begin
(),
[
&
](
instruction_ref
i
)
{
return
results
.
at
(
std
::
addressof
(
*
i
));
});
[
&
](
instruction_ref
i
)
{
return
results
.
at
(
std
::
addressof
(
*
i
));
});
result
=
ins
.
op
.
compute
(
values
);
result
=
ins
.
op
.
compute
(
ins
.
result
,
values
);
}
}
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
}
}
...
...
src/targets/cpu/cpu_target.cpp
View file @
d21778c6
...
@@ -13,10 +13,9 @@ struct cpu_convolution
...
@@ -13,10 +13,9 @@ struct cpu_convolution
std
::
string
name
()
const
{
return
"cpu::convolution"
;
}
std
::
string
name
()
const
{
return
"cpu::convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
shape
output_shape
=
compute_shape
({
args
[
0
].
get_shape
(),
args
[
1
].
get_shape
()});
argument
result
{
output_shape
};
argument
result
{
compute_shape
({
args
[
0
].
get_shape
(),
args
[
1
].
get_shape
()})};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input
,
auto
weights
)
{
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input
,
auto
weights
)
{
auto
in_n
=
input
.
get_shape
().
lens
()[
0
];
auto
in_n
=
input
.
get_shape
().
lens
()[
0
];
auto
in_c
=
input
.
get_shape
().
lens
()[
1
];
auto
in_c
=
input
.
get_shape
().
lens
()[
1
];
...
@@ -53,9 +52,9 @@ struct relu
...
@@ -53,9 +52,9 @@ struct relu
std
::
string
name
()
const
{
return
"cpu::relu"
;
}
std
::
string
name
()
const
{
return
"cpu::relu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
front
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
front
();
}
argument
compute
(
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
args
[
0
].
ge
t_shape
()
};
argument
result
{
outpu
t_shape
};
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[](
auto
x
)
{
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
[](
auto
x
)
{
...
...
test/eval_test.cpp
View file @
d21778c6
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
struct
sum_op
struct
sum_op
{
{
std
::
string
name
()
const
{
return
"sum"
;
}
std
::
string
name
()
const
{
return
"sum"
;
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
args
)
const
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
args
)
const
{
{
rtg
::
argument
result
;
rtg
::
argument
result
;
if
(
args
.
size
()
!=
2
)
if
(
args
.
size
()
!=
2
)
...
@@ -37,7 +37,7 @@ struct sum_op
...
@@ -37,7 +37,7 @@ struct sum_op
struct
minus_op
struct
minus_op
{
{
std
::
string
name
()
const
{
return
"minus"
;
}
std
::
string
name
()
const
{
return
"minus"
;
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
args
)
const
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
args
)
const
{
{
rtg
::
argument
result
;
rtg
::
argument
result
;
if
(
args
.
size
()
!=
2
)
if
(
args
.
size
()
!=
2
)
...
...
test/operation.cpp
View file @
d21778c6
...
@@ -9,7 +9,7 @@ struct simple_operation
...
@@ -9,7 +9,7 @@ struct simple_operation
int
data
=
1
;
int
data
=
1
;
std
::
string
name
()
const
{
return
"simple"
;
}
std
::
string
name
()
const
{
return
"simple"
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
simple_operation
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
simple_operation
&
op
)
{
{
os
<<
"["
<<
op
.
name
()
<<
"]"
;
os
<<
"["
<<
op
.
name
()
<<
"]"
;
...
@@ -21,7 +21,7 @@ struct simple_operation_no_print
...
@@ -21,7 +21,7 @@ struct simple_operation_no_print
{
{
std
::
string
name
()
const
{
return
"simple"
;
}
std
::
string
name
()
const
{
return
"simple"
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
};
};
void
operation_copy_test
()
void
operation_copy_test
()
...
...
tools/include/operation.hpp
View file @
d21778c6
...
@@ -25,7 +25,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -25,7 +25,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
interface
(
'
operation
'
,
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
rtg
::
operation_stream
::
operator
<<
'
)
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
rtg
::
operation_stream
::
operator
<<
'
)
)
)
%>
%>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment