Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
68c17b1b
Commit
68c17b1b
authored
Sep 23, 2022
by
charlie
Browse files
Still broken, figuring things out
parent
f02f5d98
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
170 additions
and
136 deletions
+170
-136
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+64
-66
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+3
-0
src/shape.cpp
src/shape.cpp
+2
-0
src/targets/gpu/jit/scatternd.cpp
src/targets/gpu/jit/scatternd.cpp
+4
-3
src/targets/ref/lowering.cpp
src/targets/ref/lowering.cpp
+1
-1
tools/include/operation.hpp
tools/include/operation.hpp
+96
-66
No files found.
src/include/migraphx/operation.hpp
View file @
68c17b1b
...
@@ -32,6 +32,7 @@
...
@@ -32,6 +32,7 @@
#include <utility>
#include <utility>
#include <unordered_map>
#include <unordered_map>
#include <migraphx/reflect.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
...
@@ -46,6 +47,48 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -46,6 +47,48 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
context
;
struct
context
;
struct
dyn_output
{
// original shape from the instruction
shape
ins_shape
;
// shape computed at eval time using input arguments
shape
computed_shape
;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template
<
class
F
>
struct
compute_output_shape
{
F
ins_inputs
;
operator
dyn_output
()
const
{
return
unpack
(
[](
const
auto
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
inputs
)
{
return
dyn_output
{
ins_shape
,
compute_shape
(
x
,
to_shapes
(
inputs
))};
},
ins_inputs
);
}
operator
shape
()
const
{
return
unpack
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
},
ins_inputs
);
}
};
template
<
class
T
>
auto
make_compute_output_shape
(
const
T
&
x
,
shape
output_shape
,
const
std
::
vector
<
argument
>&
inputs
)
->
decltype
(
compute_output_shape
{
pack
(
x
,
output_shape
,
inputs
)})
{
return
compute_output_shape
{
pack
(
x
,
output_shape
,
inputs
)};
}
#ifdef DOXYGEN
#ifdef DOXYGEN
/// The operation interface represents an action an instruction will perform. All
/// The operation interface represents an action an instruction will perform. All
...
@@ -199,9 +242,11 @@ auto compute_op(rank<1>,
...
@@ -199,9 +242,11 @@ auto compute_op(rank<1>,
context
&
ctx
,
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
))
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output_shape
,
input
)
,
input
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -220,7 +265,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
...
@@ -220,7 +265,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template
<
class
T
>
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
)
,
input
))
{
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
);
return
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
);
}
}
...
@@ -244,9 +289,10 @@ auto compute_op(rank<1>,
...
@@ -244,9 +289,10 @@ auto compute_op(rank<1>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
output
,
inputs
,
module_args
,
f
))
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
output
,
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
)
,
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -278,9 +324,12 @@ auto compute_op(rank<4>,
...
@@ -278,9 +324,12 @@ auto compute_op(rank<4>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
,
module_args
,
f
))
F
f
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
,
module_args
,
f
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -290,9 +339,10 @@ auto compute_op(rank<3>,
...
@@ -290,9 +339,10 @@ auto compute_op(rank<3>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
output
,
inputs
,
module_args
,
f
))
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
output
,
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
)
,
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -302,9 +352,9 @@ auto compute_op(rank<2>,
...
@@ -302,9 +352,9 @@ auto compute_op(rank<2>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
output
,
inputs
))
F
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
)
,
inputs
))
{
{
return
x
.
compute
(
output
,
inputs
);
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
)
,
inputs
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -314,9 +364,10 @@ auto compute_op(rank<1>,
...
@@ -314,9 +364,10 @@ auto compute_op(rank<1>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
))
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
)
,
inputs
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -1278,59 +1329,6 @@ inline const ValueType& any_cast(const operation& x)
...
@@ -1278,59 +1329,6 @@ inline const ValueType& any_cast(const operation& x)
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
not
(
x
==
y
);
}
inline
bool
operator
!=
(
const
operation
&
x
,
const
operation
&
y
)
{
return
not
(
x
==
y
);
}
// used for dynamic operators
struct
dyn_output
{
// original instruction output shape
shape
ins_shape
;
std
::
function
<
shape
()
>
compute_shape
;
shape
get_output_shape
()
{
if
(
output_shape
.
element_space
()
==
0
)
{
output_shape
=
compute_shape
();
}
return
output_shape
;
}
private:
// shape computed at eval time using input arguments
shape
output_shape
;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape
* If converted to dyn_output type, will compute an output shape using the input arguments
*/
template
<
class
F
>
struct
compute_output_shape
{
F
ins_inputs
;
operator
dyn_output
()
const
{
return
unpack
(
[](
const
auto
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
args
)
{
return
dyn_output
{
ins_shape
,
[
&
]()
{
compute_shape
(
x
,
to_shapes
(
args
));
}};
},
ins_inputs
);
}
operator
shape
()
const
{
return
unpack
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
},
ins_inputs
);
}
};
template
<
class
T
>
auto
make_compute_output_shape
(
const
T
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_output_shape
{
pack
(
x
,
ins_shape
,
input
)};
}
inline
value
inline
value
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
{
{
...
...
src/include/migraphx/shape.hpp
View file @
68c17b1b
...
@@ -214,6 +214,9 @@ struct shape
...
@@ -214,6 +214,9 @@ struct shape
/// Return true if the shape is dynamic
/// Return true if the shape is dynamic
bool
dynamic
()
const
;
bool
dynamic
()
const
;
/// Returns true if the shape is empty
bool
empty
()
const
;
shape
normalize_standard
()
const
;
shape
normalize_standard
()
const
;
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
...
...
src/shape.cpp
View file @
68c17b1b
...
@@ -443,6 +443,8 @@ std::string shape::type_string() const { return name(this->type()); }
...
@@ -443,6 +443,8 @@ std::string shape::type_string() const { return name(this->type()); }
bool
shape
::
dynamic
()
const
{
return
not
impl
->
m_dyn_dims
.
empty
();
}
bool
shape
::
dynamic
()
const
{
return
not
impl
->
m_dyn_dims
.
empty
();
}
bool
shape
::
empty
()
const
{
return
max_lens
().
empty
();
}
const
std
::
vector
<
shape
::
dynamic_dimension
>&
shape
::
dyn_dims
()
const
{
return
impl
->
m_dyn_dims
;
}
const
std
::
vector
<
shape
::
dynamic_dimension
>&
shape
::
dyn_dims
()
const
{
return
impl
->
m_dyn_dims
;
}
std
::
vector
<
std
::
size_t
>
shape
::
min_lens
()
const
std
::
vector
<
std
::
size_t
>
shape
::
min_lens
()
const
...
...
src/targets/gpu/jit/scatternd.cpp
View file @
68c17b1b
...
@@ -79,9 +79,10 @@ struct scatternd_compiler : compiler<scatternd_compiler>
...
@@ -79,9 +79,10 @@ struct scatternd_compiler : compiler<scatternd_compiler>
{
{
assert
(
starts_with
(
op
.
name
(),
"scatternd_"
));
assert
(
starts_with
(
op
.
name
(),
"scatternd_"
));
auto
reduction
=
op
.
name
().
substr
(
10
);
auto
reduction
=
op
.
name
().
substr
(
10
);
return
insert
(
compile_op
(
ctx
,
return
insert
(
compile_op
(
to_shapes
({
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
()}),
ctx
,
{{
"reduction"
,
reduction
}}));
to_shapes
(
std
::
vector
<
instruction_ref
>
{
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
()}),
{{
"reduction"
,
reduction
}}));
}
}
compiler_replace
insert
(
const
operation
&
op
)
const
compiler_replace
insert
(
const
operation
&
op
)
const
...
...
src/targets/ref/lowering.cpp
View file @
68c17b1b
...
@@ -252,7 +252,7 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
...
@@ -252,7 +252,7 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
else
else
{
{
padding
=
op
.
padding
;
padding
=
op
.
padding
;
output_shape
=
dyn_output
.
get_out
put_shape
()
;
output_shape
=
dyn_output
.
com
put
ed
_shape
;
}
}
argument
result
{
output_shape
};
argument
result
{
output_shape
};
...
...
tools/include/operation.hpp
View file @
68c17b1b
...
@@ -32,6 +32,7 @@
...
@@ -32,6 +32,7 @@
#include <utility>
#include <utility>
#include <unordered_map>
#include <unordered_map>
#include <migraphx/reflect.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
...
@@ -46,6 +47,80 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -46,6 +47,80 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
context
;
struct
context
;
struct
dyn_output
{
F
ins_inputs
;
dyn_output
(
F
f
)
:
ins_inputs
(
f
){};
shape
get_input_shape
()
{
if
(
ins_shape
.
empty
())
{
ins_shape
=
unpack
(
[
&
](
const
auto
&
,
shape
s
,
const
std
::
vector
<
argument
>&
)
{
return
s
;
},
ins_inputs
);
}
return
ins_shape
;
}
shape
get_output_shape
()
{
if
(
computed_shape
.
empty
())
{
computed_shape
=
unpack
(
[
&
](
const
auto
&
x
,
shape
,
const
std
::
vector
<
argument
>&
inputs
)
{
return
compute_shape
(
x
,
to_shapes
(
inputs
));
},
ins_inputs
);
}
return
computed_shape
;
}
private:
// original shape from the instruction
shape
ins_shape
;
// shape computed at eval time using input arguments
shape
computed_shape
;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template
<
class
F
>
struct
compute_output_shape
{
F
ins_inputs
;
operator
dyn_output
<
F
>
()
const
{
/*
return unpack([](const auto& x, shape ins_shape, const std::vector<argument>& inputs)
{
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
},
ins_inputs
);
*/
return
dyn_output
<
F
>
{
ins_inputs
};
}
operator
shape
()
const
{
return
unpack
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
},
ins_inputs
);
}
};
template
<
class
T
>
auto
make_compute_output_shape
(
const
T
&
x
,
shape
output_shape
,
const
std
::
vector
<
argument
>&
inputs
)
->
decltype
(
compute_output_shape
{
pack
(
x
,
output_shape
,
inputs
)})
{
return
compute_output_shape
{
pack
(
x
,
output_shape
,
inputs
)};
}
#ifdef DOXYGEN
#ifdef DOXYGEN
/// The operation interface represents an action an instruction will perform. All
/// The operation interface represents an action an instruction will perform. All
...
@@ -199,9 +274,11 @@ auto compute_op(rank<1>,
...
@@ -199,9 +274,11 @@ auto compute_op(rank<1>,
context
&
ctx
,
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
))
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output_shape
,
input
)
,
input
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -220,7 +297,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
...
@@ -220,7 +297,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template
<
class
T
>
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
)
,
input
))
{
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
);
return
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
);
}
}
...
@@ -244,9 +321,10 @@ auto compute_op(rank<1>,
...
@@ -244,9 +321,10 @@ auto compute_op(rank<1>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
output
,
inputs
,
module_args
,
f
))
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
output
,
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
)
,
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -278,9 +356,12 @@ auto compute_op(rank<4>,
...
@@ -278,9 +356,12 @@ auto compute_op(rank<4>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
,
module_args
,
f
))
F
f
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
,
module_args
,
f
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -290,9 +371,10 @@ auto compute_op(rank<3>,
...
@@ -290,9 +371,10 @@ auto compute_op(rank<3>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
output
,
inputs
,
module_args
,
f
))
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
output
,
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
)
,
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -302,9 +384,9 @@ auto compute_op(rank<2>,
...
@@ -302,9 +384,9 @@ auto compute_op(rank<2>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
output
,
inputs
))
F
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
)
,
inputs
))
{
{
return
x
.
compute
(
output
,
inputs
);
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
)
,
inputs
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -314,9 +396,10 @@ auto compute_op(rank<1>,
...
@@ -314,9 +396,10 @@ auto compute_op(rank<1>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
))
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
compute_output_shape
<
T
>
{
x
,
output
,
inputs
},
inputs
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
)
,
inputs
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -563,59 +646,6 @@ lifetime get_lifetime_op(const T&)
...
@@ -563,59 +646,6 @@ lifetime get_lifetime_op(const T&)
return
not
(
x
==
y
);
return
not
(
x
==
y
);
}
}
// used for dynamic operators
struct
dyn_output
{
// original instruction output shape
shape
ins_shape
;
std
::
function
<
shape
()
>
compute_shape
;
shape
get_output_shape
()
{
if
(
output_shape
.
element_space
()
==
0
)
{
output_shape
=
compute_shape
();
}
return
output_shape
;
}
private:
// shape computed at eval time using input arguments
shape
output_shape
;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape
* If converted to dyn_output type, will compute an output shape using the input arguments
*/
template
<
class
F
>
struct
compute_output_shape
{
F
ins_inputs
;
operator
dyn_output
()
const
{
return
unpack
(
[](
const
auto
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
args
)
{
return
dyn_output
{
ins_shape
,
[
&
]()
{
compute_shape
(
x
,
to_shapes
(
args
));
}};
},
ins_inputs
);
}
operator
shape
()
const
{
return
unpack
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
},
ins_inputs
);
}
};
template
<
class
T
>
auto
make_compute_output_shape
(
const
T
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_output_shape
{
pack
(
x
,
ins_shape
,
input
)};
}
inline
value
inline
value
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
compile
(
operation
&
op
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment