Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a3bdb08f
Unverified
Commit
a3bdb08f
authored
Jan 08, 2019
by
Paul Fultz II
Committed by
GitHub
Jan 08, 2019
Browse files
Merge pull request #147 from ROCmSoftwarePlatform/const-eval
Add const evaluation of instructions
parents
f98b01b3
c5aa614f
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
356 additions
and
42 deletions
+356
-42
cppcheck.rules
cppcheck.rules
+1
-1
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+2
-0
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+98
-5
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+34
-16
src/instruction.cpp
src/instruction.cpp
+21
-0
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+1
-18
test/const_eval_test.cpp
test/const_eval_test.cpp
+125
-0
tools/include/operation.hpp
tools/include/operation.hpp
+74
-2
No files found.
cppcheck.rules
View file @
a3bdb08f
...
@@ -74,7 +74,7 @@
...
@@ -74,7 +74,7 @@
</message>
</message>
</rule>
</rule>
<rule>
<rule>
<pattern>
(fclose|free|hipFree|hipHostFree|hipFreeArray|hipMemFree|hipStreamDestroy|hipEventDestroy|hipArrayDestroy|hipCtxDestroy|hipDestroyTextureObject|hipDestroySurfaceObject) \(
</pattern>
<pattern>
\\W
(fclose|free|hipFree|hipHostFree|hipFreeArray|hipMemFree|hipStreamDestroy|hipEventDestroy|hipArrayDestroy|hipCtxDestroy|hipDestroyTextureObject|hipDestroySurfaceObject) \(
</pattern>
<message>
<message>
<id>
useManagePointer
</id>
<id>
useManagePointer
</id>
<severity>
style
</severity>
<severity>
style
</severity>
...
...
src/include/migraphx/instruction.hpp
View file @
a3bdb08f
...
@@ -71,6 +71,8 @@ struct instruction
...
@@ -71,6 +71,8 @@ struct instruction
static
void
static
void
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
);
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
);
argument
eval
()
const
;
static
instruction_ref
get_output_alias
(
instruction_ref
ins
);
static
instruction_ref
get_output_alias
(
instruction_ref
ins
);
private:
private:
...
...
src/include/migraphx/operation.hpp
View file @
a3bdb08f
...
@@ -53,6 +53,9 @@ struct operation
...
@@ -53,6 +53,9 @@ struct operation
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
};
};
/// Returns true if operation does not require a context to run compute
bool
is_context_free
(
const
operation
&
x
);
#else
#else
namespace
operation_stream
{
namespace
operation_stream
{
...
@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
...
@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
}
// namespace operation_equal
}
// namespace operation_equal
template
<
class
T
>
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
T
&
x
,
context
&
ctx
,
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
...
@@ -99,6 +102,14 @@ auto compute_op(rank<1>,
...
@@ -99,6 +102,14 @@ auto compute_op(rank<1>,
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
{
...
@@ -110,7 +121,53 @@ template <class T>
...
@@ -110,7 +121,53 @@ template <class T>
argument
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
{
return
compute_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
return
compute_op
(
rank
<
2
>
{},
x
,
ctx
,
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
std
::
declval
<
context
&>
()),
output_shape
,
input
))
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable without a context: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_op
(
rank
<
2
>
{},
x
,
output_shape
,
input
);
}
template
<
class
T
>
auto
is_context_free_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
),
std
::
true_type
{});
template
<
class
T
>
auto
is_context_free_op
(
rank
<
0
>
,
const
T
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
->
std
::
false_type
;
template
<
class
T
>
auto
is_context_free_op
(
const
T
&
x
)
->
decltype
(
is_context_free_op
(
rank
<
1
>
{},
x
,
std
::
declval
<
const
shape
&>
(),
std
::
declval
<
std
::
vector
<
argument
>>
()))
{
return
{};
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -138,9 +195,11 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
...
@@ -138,9 +195,11 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
* struct operation
* struct operation
* {
* {
* std::string name() const;
* std::string name() const;
* bool is_context_free() const;
* int output_alias(const std::vector<shape>& input) const;
* int output_alias(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* friend bool operator==(const operation & x,const operation & y) ;
* };
* };
...
@@ -210,6 +269,12 @@ struct operation
...
@@ -210,6 +269,12 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
name
();
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
}
bool
is_context_free
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
is_context_free
();
}
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
@@ -228,6 +293,12 @@ struct operation
...
@@ -228,6 +293,12 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
ctx
,
output
,
input
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
ctx
,
output
,
input
);
}
}
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
output
,
input
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
{
assert
(
op
.
private_detail_te_handle_mem_var
);
assert
(
op
.
private_detail_te_handle_mem_var
);
...
@@ -248,12 +319,14 @@ struct operation
...
@@ -248,12 +319,14 @@ struct operation
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
bool
is_context_free
()
const
=
0
;
virtual
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
argument
virtual
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
bool
operator
==
(
const
operation
&
y
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
bool
operator
==
(
const
operation
&
y
)
const
=
0
;
};
};
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
...
@@ -286,6 +359,12 @@ struct operation
...
@@ -286,6 +359,12 @@ struct operation
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
bool
is_context_free
()
const
override
{
return
is_context_free_op
(
private_detail_te_value
);
}
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
override
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
override
{
{
...
@@ -306,6 +385,12 @@ struct operation
...
@@ -306,6 +385,12 @@ struct operation
return
compute_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
return
compute_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
}
}
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
override
{
return
compute_op
(
private_detail_te_value
,
output
,
input
);
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
{
{
using
migraphx
::
operation_stream
::
operator
<<
;
using
migraphx
::
operation_stream
::
operator
<<
;
...
@@ -385,6 +470,14 @@ inline const ValueType& any_cast(const operation& x)
...
@@ -385,6 +470,14 @@ inline const ValueType& any_cast(const operation& x)
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
is_context_free
(
const
operation
&
op
)
{
return
op
.
is_context_free
();
}
template
<
class
T
>
bool
is_context_free
(
const
T
&
x
)
{
return
is_context_free_op
(
x
);
}
#endif
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/operators.hpp
View file @
a3bdb08f
...
@@ -16,7 +16,7 @@ namespace op {
...
@@ -16,7 +16,7 @@ namespace op {
struct
not_computable
struct
not_computable
{
{
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
{
MIGRAPHX_THROW
(
"not computable"
);
MIGRAPHX_THROW
(
"not computable"
);
}
}
...
@@ -296,7 +296,7 @@ struct transpose
...
@@ -296,7 +296,7 @@ struct transpose
}
}
return
{
t
,
output_lens
,
output_strides
};
return
{
t
,
output_lens
,
output_strides
};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
...
@@ -370,6 +370,27 @@ struct concat
...
@@ -370,6 +370,27 @@ struct concat
new_lens
[
axis
]
=
new_dim_axis
;
new_lens
[
axis
]
=
new_dim_axis
;
return
{
type
,
new_lens
};
return
{
type
,
new_lens
};
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
std
::
vector
<
std
::
size_t
>
coffsets
=
compute_offsets
(
output_shape
,
args
);
for
(
std
::
size_t
l
=
0
;
l
<
args
.
size
();
l
++
)
{
auto
argl
=
args
[
l
];
std
::
size_t
nelements
=
argl
.
get_shape
().
elements
();
visit_all
(
result
,
argl
)([
&
](
auto
output
,
auto
input
)
{
auto
slice_shape
=
shape
{
output_shape
.
type
(),
input
.
get_shape
().
lens
(),
output_shape
.
strides
()};
auto
slice
=
make_view
(
slice_shape
,
output
.
data
()
+
coffsets
[
l
]);
// cppcheck-suppress useStlAlgorithm
for
(
std
::
size_t
i
=
0
;
i
<
nelements
;
i
++
)
{
slice
[
i
]
=
input
[
i
];
}
});
}
return
result
;
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
@@ -437,7 +458,7 @@ struct slice
...
@@ -437,7 +458,7 @@ struct slice
}
}
return
shape
{
t
,
new_lens
,
old_strides
};
return
shape
{
t
,
new_lens
,
old_strides
};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
input
=
args
[
0
];
auto
input
=
args
[
0
];
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
output_shape
.
type_size
();
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
output_shape
.
type_size
();
...
@@ -487,7 +508,7 @@ struct squeeze
...
@@ -487,7 +508,7 @@ struct squeeze
}
}
return
shape
{
type
,
new_lens
};
return
shape
{
type
,
new_lens
};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
...
@@ -526,7 +547,7 @@ struct unsqueeze
...
@@ -526,7 +547,7 @@ struct unsqueeze
}
}
return
shape
{
type
,
new_lens
};
return
shape
{
type
,
new_lens
};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
...
@@ -578,7 +599,7 @@ struct reshape
...
@@ -578,7 +599,7 @@ struct reshape
MIGRAPHX_THROW
(
"Wrong number of elements for reshape"
);
MIGRAPHX_THROW
(
"Wrong number of elements for reshape"
);
return
s
;
return
s
;
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
...
@@ -624,7 +645,7 @@ struct identity
...
@@ -624,7 +645,7 @@ struct identity
{
{
std
::
string
name
()
const
{
return
"identity"
;
}
std
::
string
name
()
const
{
return
"identity"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
at
(
0
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
at
(
0
);
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
}
...
@@ -742,7 +763,7 @@ struct flatten
...
@@ -742,7 +763,7 @@ struct flatten
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
...
@@ -794,7 +815,7 @@ struct broadcast
...
@@ -794,7 +815,7 @@ struct broadcast
return
{
t
,
broadcast_shape
.
lens
(),
std
::
move
(
bcast_strides
)};
return
{
t
,
broadcast_shape
.
lens
(),
std
::
move
(
bcast_strides
)};
}
}
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
}
...
@@ -836,7 +857,7 @@ struct multibroadcast
...
@@ -836,7 +857,7 @@ struct multibroadcast
}
}
return
{
t
,
output_lens
,
bcast_strides
};
return
{
t
,
output_lens
,
bcast_strides
};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
}
...
@@ -858,7 +879,7 @@ struct scalar
...
@@ -858,7 +879,7 @@ struct scalar
return
{
t
,
scalar_bcast
.
lens
(),
strides
};
return
{
t
,
scalar_bcast
.
lens
(),
strides
};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
}
...
@@ -923,7 +944,7 @@ struct load
...
@@ -923,7 +944,7 @@ struct load
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
return
s
;
return
s
;
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
return
{
s
,
args
[
0
].
data
()
+
offset
};
return
{
s
,
args
[
0
].
data
()
+
offset
};
}
}
...
@@ -946,10 +967,7 @@ struct outline
...
@@ -946,10 +967,7 @@ struct outline
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
return
s
;
return
s
;
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
return
{
s
,
nullptr
};
}
{
return
{
s
,
nullptr
};
}
};
};
}
// namespace op
}
// namespace op
...
...
src/instruction.cpp
View file @
a3bdb08f
...
@@ -170,6 +170,27 @@ std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
...
@@ -170,6 +170,27 @@ std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
return
shapes
;
return
shapes
;
}
}
argument
instruction
::
eval
()
const
{
if
(
op
.
name
()
==
"@literal"
)
{
return
this
->
get_literal
().
get_argument
();
}
if
(
is_context_free
(
op
))
{
std
::
vector
<
argument
>
args
;
for
(
auto
&&
arg
:
this
->
inputs
())
{
argument
a
=
arg
->
eval
();
if
(
a
.
empty
())
return
{};
args
.
push_back
(
a
);
}
return
op
.
compute
(
result
,
args
);
}
return
{};
}
instruction_ref
instruction
::
get_output_alias
(
instruction_ref
ins
)
instruction_ref
instruction
::
get_output_alias
(
instruction_ref
ins
)
{
{
auto
i
=
ins
->
get_operator
().
output_alias
(
compute_shapes
(
ins
->
inputs
()));
auto
i
=
ins
->
get_operator
().
output_alias
(
compute_shapes
(
ins
->
inputs
()));
...
...
src/targets/cpu/lowering.cpp
View file @
a3bdb08f
...
@@ -299,24 +299,7 @@ struct cpu_concat
...
@@ -299,24 +299,7 @@ struct cpu_concat
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
return
op
.
compute
(
output_shape
,
std
::
move
(
args
));
std
::
vector
<
std
::
size_t
>
coffsets
=
op
.
compute_offsets
(
output_shape
,
args
);
for
(
std
::
size_t
l
=
0
;
l
<
args
.
size
();
l
++
)
{
auto
argl
=
args
[
l
];
std
::
size_t
nelements
=
argl
.
get_shape
().
elements
();
visit_all
(
result
,
argl
)([
&
](
auto
output
,
auto
input
)
{
auto
slice_shape
=
shape
{
output_shape
.
type
(),
input
.
get_shape
().
lens
(),
output_shape
.
strides
()};
auto
slice
=
make_view
(
slice_shape
,
output
.
data
()
+
coffsets
[
l
]);
// cppcheck-suppress useStlAlgorithm
for
(
std
::
size_t
i
=
0
;
i
<
nelements
;
i
++
)
{
slice
[
i
]
=
input
[
i
];
}
});
}
return
result
;
}
}
};
};
...
...
test/const_eval_test.cpp
0 → 100644
View file @
a3bdb08f
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
struct
sum_cf_op
{
std
::
string
name
()
const
{
return
"sum_cf"
;
}
migraphx
::
argument
compute
(
const
migraphx
::
shape
&
,
std
::
vector
<
migraphx
::
argument
>
args
)
const
{
migraphx
::
argument
result
;
if
(
args
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"Wrong args"
);
if
(
args
[
0
].
get_shape
()
!=
args
[
1
].
get_shape
())
MIGRAPHX_THROW
(
"Wrong args"
);
if
(
args
[
0
].
get_shape
().
lens
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"Wrong args"
);
if
(
args
[
0
].
get_shape
().
lens
().
front
()
!=
1
)
MIGRAPHX_THROW
(
"Wrong args"
);
args
[
0
].
visit_at
([
&
](
auto
x
)
{
args
[
1
].
visit_at
([
&
](
auto
y
)
{
result
=
migraphx
::
literal
{
x
+
y
}.
get_argument
();
});
});
return
result
;
}
migraphx
::
shape
compute_shape
(
std
::
vector
<
migraphx
::
shape
>
inputs
)
const
{
if
(
inputs
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"Wrong inputs"
);
return
inputs
.
front
();
}
};
struct
non_computable_cf
{
std
::
string
name
()
const
{
return
"non_computable"
;
}
migraphx
::
shape
compute_shape
(
std
::
vector
<
migraphx
::
shape
>
inputs
)
const
{
if
(
inputs
.
empty
())
return
{};
return
inputs
.
front
();
}
};
struct
test_context
{
void
finish
()
const
{}
};
TEST_CASE
(
literal_test
)
{
migraphx
::
program
p
;
auto
lit
=
p
.
add_literal
(
1
);
CHECK
(
lit
->
eval
()
==
migraphx
::
literal
{
1
});
}
TEST_CASE
(
param_test
)
{
migraphx
::
program
p
;
auto
lit
=
p
.
add_parameter
(
"param"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}});
CHECK
(
lit
->
eval
().
empty
());
}
TEST_CASE
(
op_test1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_cf_op
{},
one
,
two
);
CHECK
(
sum
->
eval
()
==
migraphx
::
literal
{
3
});
}
TEST_CASE
(
op_test2
)
{
migraphx
::
program
p
;
auto
x
=
p
.
add_parameter
(
"param"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}});
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_cf_op
{},
x
,
two
);
CHECK
(
sum
->
eval
().
empty
());
}
TEST_CASE
(
op_test3
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_cf_op
{},
sum1
,
two
);
CHECK
(
sum2
->
eval
().
empty
());
}
TEST_CASE
(
compute_op_c
)
{
migraphx
::
operation
op
=
sum_op
{};
auto
one
=
migraphx
::
literal
{
1
}.
get_argument
();
auto
two
=
migraphx
::
literal
{
2
}.
get_argument
();
EXPECT
(
test
::
throws
([
&
]
{
op
.
compute
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
{
one
,
two
});
}));
}
TEST_CASE
(
compute_nop_c
)
{
migraphx
::
operation
op
=
non_computable_cf
{};
auto
one
=
migraphx
::
literal
{
1
}.
get_argument
();
auto
two
=
migraphx
::
literal
{
2
}.
get_argument
();
EXPECT
(
test
::
throws
([
&
]
{
op
.
compute
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
{
one
,
two
});
}));
}
TEST_CASE
(
compute_nop_context
)
{
migraphx
::
operation
op
=
non_computable_cf
{};
auto
one
=
migraphx
::
literal
{
1
}.
get_argument
();
auto
two
=
migraphx
::
literal
{
2
}.
get_argument
();
migraphx
::
context
ctx
=
test_context
{};
EXPECT
(
test
::
throws
([
&
]
{
op
.
compute
(
ctx
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
{
one
,
two
});
}));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
tools/include/operation.hpp
View file @
a3bdb08f
...
@@ -53,6 +53,9 @@ struct operation
...
@@ -53,6 +53,9 @@ struct operation
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
};
};
/// Returns true if operation does not require a context to run compute
bool
is_context_free
(
const
operation
&
x
);
#else
#else
namespace
operation_stream
{
namespace
operation_stream
{
...
@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
...
@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
}
// namespace operation_equal
}
// namespace operation_equal
template
<
class
T
>
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
T
&
x
,
context
&
ctx
,
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
...
@@ -99,6 +102,14 @@ auto compute_op(rank<1>,
...
@@ -99,6 +102,14 @@ auto compute_op(rank<1>,
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
{
...
@@ -110,7 +121,53 @@ template <class T>
...
@@ -110,7 +121,53 @@ template <class T>
argument
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
{
return
compute_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
return
compute_op
(
rank
<
2
>
{},
x
,
ctx
,
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
std
::
declval
<
context
&>
()),
output_shape
,
input
))
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable without a context: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_op
(
rank
<
2
>
{},
x
,
output_shape
,
input
);
}
template
<
class
T
>
auto
is_context_free_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
),
std
::
true_type
{});
template
<
class
T
>
auto
is_context_free_op
(
rank
<
0
>
,
const
T
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
->
std
::
false_type
;
template
<
class
T
>
auto
is_context_free_op
(
const
T
&
x
)
->
decltype
(
is_context_free_op
(
rank
<
1
>
{},
x
,
std
::
declval
<
const
shape
&>
(),
std
::
declval
<
std
::
vector
<
argument
>>
()))
{
return
{};
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -136,6 +193,7 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
...
@@ -136,6 +193,7 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
interface
(
interface
(
'
operation
'
,
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
is_context_free
'
,
returns
=
'
bool
'
,
const
=
True
,
default
=
'
is_context_free_op
'
),
virtual
(
'
output_alias
'
,
virtual
(
'
output_alias
'
,
returns
=
'
int
'
,
returns
=
'
int
'
,
input
=
'
const
std
::
vector
<
shape
>&
'
,
input
=
'
const
std
::
vector
<
shape
>&
'
,
...
@@ -149,6 +207,12 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
...
@@ -149,6 +207,12 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
input
=
'
const
std
::
vector
<
argument
>&
'
,
input
=
'
const
std
::
vector
<
argument
>&
'
,
const
=
True
,
const
=
True
,
default
=
'
compute_op
'
),
default
=
'
compute_op
'
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
output
=
'
const
shape
&
'
,
input
=
'
const
std
::
vector
<
argument
>&
'
,
const
=
True
,
default
=
'
compute_op
'
),
friend
(
'
operator
<<
'
,
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
...
@@ -165,6 +229,14 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
...
@@ -165,6 +229,14 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return
!
(
x
==
y
);
return
!
(
x
==
y
);
}
}
inline
bool
is_context_free
(
const
operation
&
op
)
{
return
op
.
is_context_free
();
}
template
<
class
T
>
bool
is_context_free
(
const
T
&
x
)
{
return
is_context_free_op
(
x
);
}
#endif
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment