"...composable_kernel_onnxruntime.git" did not exist on "4a2a56c22f75263d70c710950ab7313f072a2523"
Commit 610aa6a7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add a capture operator.

parent 0861bb2a
#ifndef MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CATTURE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct capture
{
std::function<void(std::vector<argument>) f;
template<class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.f, "func"));
}
std::string name() const { return "capputure"; }
shape compute_shape(std::vector<shape> inputs) const
{
return inputs.front();
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
args.push_back(result);
f(args);
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <migraphx/op/batch_norm.hpp> #include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/binary.hpp> #include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp> #include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp> #include <migraphx/op/concat.hpp>
#include <migraphx/op/contiguous.hpp> #include <migraphx/op/contiguous.hpp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment