Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1c8a45ee
Commit
1c8a45ee
authored
Jan 02, 2019
by
Paul
Browse files
Add eval function to instruction
parent
46b3e7da
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
214 additions
and
20 deletions
+214
-20
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+2
-0
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+98
-5
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+13
-13
src/instruction.cpp
src/instruction.cpp
+20
-0
tools/include/operation.hpp
tools/include/operation.hpp
+81
-2
No files found.
src/include/migraphx/instruction.hpp
View file @
1c8a45ee
...
@@ -71,6 +71,8 @@ struct instruction
...
@@ -71,6 +71,8 @@ struct instruction
static
void
static
void
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
);
replace
(
instruction_ref
ins
,
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
);
argument
eval
()
const
;
static
instruction_ref
get_output_alias
(
instruction_ref
ins
);
static
instruction_ref
get_output_alias
(
instruction_ref
ins
);
private:
private:
...
...
src/include/migraphx/operation.hpp
View file @
1c8a45ee
...
@@ -53,6 +53,9 @@ struct operation
...
@@ -53,6 +53,9 @@ struct operation
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
};
};
/// Returns true if operation does not require a context to run compute
bool
is_context_free
(
const
operation
&
x
);
#else
#else
namespace
operation_stream
{
namespace
operation_stream
{
...
@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
...
@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
}
// namespace operation_equal
}
// namespace operation_equal
template
<
class
T
>
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
T
&
x
,
context
&
ctx
,
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
...
@@ -99,6 +102,14 @@ auto compute_op(rank<1>,
...
@@ -99,6 +102,14 @@ auto compute_op(rank<1>,
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
{
...
@@ -110,7 +121,53 @@ template <class T>
...
@@ -110,7 +121,53 @@ template <class T>
argument
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
{
return
compute_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
return
compute_op
(
rank
<
2
>
{},
x
,
ctx
,
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
std
::
declval
<
context
&>
()),
output_shape
,
input
))
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable without a context: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_op
(
rank
<
2
>
{},
x
,
output_shape
,
input
);
}
template
<
class
T
>
auto
is_context_free_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
),
std
::
true_type
{});
template
<
class
T
>
auto
is_context_free_op
(
rank
<
0
>
,
const
T
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
->
std
::
false_type
;
template
<
class
T
>
auto
is_context_free_op
(
const
T
&
x
)
->
decltype
(
is_context_free_op
(
rank
<
1
>
{},
x
,
std
::
declval
<
const
shape
&>
(),
std
::
declval
<
std
::
vector
<
argument
>>
()))
{
return
{};
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -138,9 +195,11 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
...
@@ -138,9 +195,11 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
* struct operation
* struct operation
* {
* {
* std::string name() const;
* std::string name() const;
* bool is_context_free() const;
* int output_alias(const std::vector<shape>& input) const;
* int output_alias(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* friend bool operator==(const operation & x,const operation & y) ;
* };
* };
...
@@ -210,6 +269,12 @@ struct operation
...
@@ -210,6 +269,12 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
name
();
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
}
bool
is_context_free
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
is_context_free
();
}
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
@@ -228,6 +293,12 @@ struct operation
...
@@ -228,6 +293,12 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
ctx
,
output
,
input
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
ctx
,
output
,
input
);
}
}
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
output
,
input
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
{
assert
(
op
.
private_detail_te_handle_mem_var
);
assert
(
op
.
private_detail_te_handle_mem_var
);
...
@@ -248,10 +319,12 @@ struct operation
...
@@ -248,10 +319,12 @@ struct operation
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
bool
is_context_free
()
const
=
0
;
virtual
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
argument
virtual
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
virtual
bool
operator
==
(
const
operation
&
y
)
const
=
0
;
virtual
bool
operator
==
(
const
operation
&
y
)
const
=
0
;
};
};
...
@@ -286,6 +359,12 @@ struct operation
...
@@ -286,6 +359,12 @@ struct operation
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
bool
is_context_free
()
const
override
{
return
is_context_free_op
(
private_detail_te_value
);
}
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
override
int
output_alias
(
const
std
::
vector
<
shape
>&
input
)
const
override
{
{
...
@@ -306,6 +385,12 @@ struct operation
...
@@ -306,6 +385,12 @@ struct operation
return
compute_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
return
compute_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
}
}
argument
compute
(
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
override
{
return
compute_op
(
private_detail_te_value
,
output
,
input
);
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
{
{
using
migraphx
::
operation_stream
::
operator
<<
;
using
migraphx
::
operation_stream
::
operator
<<
;
...
@@ -385,6 +470,14 @@ inline const ValueType& any_cast(const operation& x)
...
@@ -385,6 +470,14 @@ inline const ValueType& any_cast(const operation& x)
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
!
(
x
==
y
);
}
inline
bool
is_context_free
(
const
operation
&
op
)
{
return
op
.
is_context_free
();
}
template
<
class
T
>
bool
is_context_free
(
const
T
&
x
)
{
return
is_context_free_op
(
x
);
}
#endif
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/operators.hpp
View file @
1c8a45ee
...
@@ -16,7 +16,7 @@ namespace op {
...
@@ -16,7 +16,7 @@ namespace op {
struct
not_computable
struct
not_computable
{
{
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
{
MIGRAPHX_THROW
(
"not computable"
);
MIGRAPHX_THROW
(
"not computable"
);
}
}
...
@@ -296,7 +296,7 @@ struct transpose
...
@@ -296,7 +296,7 @@ struct transpose
}
}
return
{
t
,
output_lens
,
output_strides
};
return
{
t
,
output_lens
,
output_strides
};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
...
@@ -437,7 +437,7 @@ struct slice
...
@@ -437,7 +437,7 @@ struct slice
}
}
return
shape
{
t
,
new_lens
,
old_strides
};
return
shape
{
t
,
new_lens
,
old_strides
};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
input
=
args
[
0
];
auto
input
=
args
[
0
];
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
output_shape
.
type_size
();
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
output_shape
.
type_size
();
...
@@ -487,7 +487,7 @@ struct squeeze
...
@@ -487,7 +487,7 @@ struct squeeze
}
}
return
shape
{
type
,
new_lens
};
return
shape
{
type
,
new_lens
};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
...
@@ -526,7 +526,7 @@ struct unsqueeze
...
@@ -526,7 +526,7 @@ struct unsqueeze
}
}
return
shape
{
type
,
new_lens
};
return
shape
{
type
,
new_lens
};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
...
@@ -578,7 +578,7 @@ struct reshape
...
@@ -578,7 +578,7 @@ struct reshape
MIGRAPHX_THROW
(
"Wrong number of elements for reshape"
);
MIGRAPHX_THROW
(
"Wrong number of elements for reshape"
);
return
s
;
return
s
;
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
...
@@ -624,7 +624,7 @@ struct identity
...
@@ -624,7 +624,7 @@ struct identity
{
{
std
::
string
name
()
const
{
return
"identity"
;
}
std
::
string
name
()
const
{
return
"identity"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
at
(
0
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
at
(
0
);
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
}
...
@@ -742,7 +742,7 @@ struct flatten
...
@@ -742,7 +742,7 @@ struct flatten
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
}
...
@@ -794,7 +794,7 @@ struct broadcast
...
@@ -794,7 +794,7 @@ struct broadcast
return
{
t
,
broadcast_shape
.
lens
(),
std
::
move
(
bcast_strides
)};
return
{
t
,
broadcast_shape
.
lens
(),
std
::
move
(
bcast_strides
)};
}
}
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
}
...
@@ -836,7 +836,7 @@ struct multibroadcast
...
@@ -836,7 +836,7 @@ struct multibroadcast
}
}
return
{
t
,
output_lens
,
bcast_strides
};
return
{
t
,
output_lens
,
bcast_strides
};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
}
...
@@ -858,7 +858,7 @@ struct scalar
...
@@ -858,7 +858,7 @@ struct scalar
return
{
t
,
scalar_bcast
.
lens
(),
strides
};
return
{
t
,
scalar_bcast
.
lens
(),
strides
};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
}
...
@@ -923,7 +923,7 @@ struct load
...
@@ -923,7 +923,7 @@ struct load
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
return
s
;
return
s
;
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
return
{
s
,
args
[
0
].
data
()
+
offset
};
return
{
s
,
args
[
0
].
data
()
+
offset
};
}
}
...
@@ -946,7 +946,7 @@ struct outline
...
@@ -946,7 +946,7 @@ struct outline
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
return
s
;
return
s
;
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
{
return
{
s
,
nullptr
};
return
{
s
,
nullptr
};
}
}
...
...
src/instruction.cpp
View file @
1c8a45ee
...
@@ -170,6 +170,26 @@ std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
...
@@ -170,6 +170,26 @@ std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
return
shapes
;
return
shapes
;
}
}
argument
instruction
::
eval
()
const
{
if
(
op
.
name
()
==
"@literal"
)
{
return
this
->
get_literal
().
get_argument
();
}
if
(
is_context_free
(
op
))
{
std
::
vector
<
argument
>
args
;
for
(
auto
&&
arg
:
this
->
inputs
())
{
argument
a
=
arg
->
eval
();
if
(
a
.
empty
())
return
{};
args
.
push_back
(
a
);
}
return
op
.
compute
(
result
,
args
);
}
return
{};
}
instruction_ref
instruction
::
get_output_alias
(
instruction_ref
ins
)
instruction_ref
instruction
::
get_output_alias
(
instruction_ref
ins
)
{
{
auto
i
=
ins
->
get_operator
().
output_alias
(
compute_shapes
(
ins
->
inputs
()));
auto
i
=
ins
->
get_operator
().
output_alias
(
compute_shapes
(
ins
->
inputs
()));
...
...
tools/include/operation.hpp
View file @
1c8a45ee
...
@@ -53,6 +53,9 @@ struct operation
...
@@ -53,6 +53,9 @@ struct operation
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
};
};
/// Returns true if operation does not require a context to run compute
bool
is_context_free
(
const
operation
&
x
);
#else
#else
namespace
operation_stream
{
namespace
operation_stream
{
...
@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
...
@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
}
// namespace operation_equal
}
// namespace operation_equal
template
<
class
T
>
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
T
&
x
,
context
&
ctx
,
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
...
@@ -99,6 +102,17 @@ auto compute_op(rank<1>,
...
@@ -99,6 +102,17 @@ auto compute_op(rank<1>,
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
{
...
@@ -110,7 +124,54 @@ template <class T>
...
@@ -110,7 +124,54 @@ template <class T>
argument
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
{
return
compute_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
return
compute_op
(
rank
<
2
>
{},
x
,
ctx
,
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
std
::
declval
<
context
&>
()),
output_shape
,
input
))
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable without a context: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_op
(
rank
<
2
>
{},
x
,
output_shape
,
input
);
}
template
<
class
T
>
auto
is_context_free_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
),
std
::
true_type
{});
template
<
class
T
>
auto
is_context_free_op
(
rank
<
0
>
,
const
T
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
->
std
::
false_type
;
template
<
class
T
>
auto
is_context_free_op
(
const
T
&
x
)
->
decltype
(
is_context_free_op
(
rank
<
1
>
{},
x
,
std
::
declval
<
const
shape
&>
(),
std
::
declval
<
std
::
vector
<
argument
>>
()))
{
return
{};
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -136,6 +197,7 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
...
@@ -136,6 +197,7 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
interface
(
interface
(
'
operation
'
,
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
is_context_free
'
,
returns
=
'
bool
'
,
const
=
True
,
default
=
'
is_context_free_op
'
),
virtual
(
'
output_alias
'
,
virtual
(
'
output_alias
'
,
returns
=
'
int
'
,
returns
=
'
int
'
,
input
=
'
const
std
::
vector
<
shape
>&
'
,
input
=
'
const
std
::
vector
<
shape
>&
'
,
...
@@ -149,6 +211,12 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
...
@@ -149,6 +211,12 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
input
=
'
const
std
::
vector
<
argument
>&
'
,
input
=
'
const
std
::
vector
<
argument
>&
'
,
const
=
True
,
const
=
True
,
default
=
'
compute_op
'
),
default
=
'
compute_op
'
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
output
=
'
const
shape
&
'
,
input
=
'
const
std
::
vector
<
argument
>&
'
,
const
=
True
,
default
=
'
compute_op
'
),
friend
(
'
operator
<<
'
,
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
...
@@ -165,6 +233,17 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
...
@@ -165,6 +233,17 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return
!
(
x
==
y
);
return
!
(
x
==
y
);
}
}
inline
bool
is_context_free
(
const
operation
&
op
)
{
return
op
.
is_context_free
();
}
template
<
class
T
>
bool
is_context_free
(
const
T
&
x
)
{
return
is_context_free_op
(
x
);
}
#endif
#endif
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment