diff --git a/src/include/migraphx/op/capture.hpp b/src/include/migraphx/op/capture.hpp index f33eab9bb..80ffcbe6b 100644 --- a/src/include/migraphx/op/capture.hpp +++ b/src/include/migraphx/op/capture.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -29,7 +30,9 @@ struct capture shape compute_shape(std::vector inputs) const { return inputs.front(); } - argument compute(const shape&, std::vector args) const + argument compute(const shape&, std::vector args) const { return args.front(); } + + argument compute(context&, const shape&, const std::vector& args) const { if(f) { diff --git a/src/include/migraphx/operation.hpp b/src/include/migraphx/operation.hpp index 922eabd67..56108a871 100644 --- a/src/include/migraphx/operation.hpp +++ b/src/include/migraphx/operation.hpp @@ -271,25 +271,25 @@ auto compute_op(rank<3>, template auto compute_op(rank<2>, const T& x, - context&, + context& ctx, const shape& output, const std::vector& inputs, const std::vector&, - F) -> decltype(x.compute(output, inputs)) + F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs)) { - return x.compute(output, inputs); + return x.compute(auto_any_cast(ctx), output, inputs); } template auto compute_op(rank<1>, const T& x, - context& ctx, + context&, const shape& output, const std::vector& inputs, const std::vector&, - F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs)) + F) -> decltype(x.compute(output, inputs)) { - return x.compute(auto_any_cast(ctx), output, inputs); + return x.compute(output, inputs); } template diff --git a/tools/include/operation.hpp b/tools/include/operation.hpp index 0c49edfaf..ef9927cdc 100644 --- a/tools/include/operation.hpp +++ b/tools/include/operation.hpp @@ -271,25 +271,25 @@ auto compute_op(rank<3>, template auto compute_op(rank<2>, const T& x, - context&, + context& ctx, const shape& output, const std::vector& inputs, const std::vector&, - F) -> decltype(x.compute(output, inputs)) + F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs)) { - return x.compute(output, inputs); + return x.compute(auto_any_cast(ctx), output, inputs); } template auto compute_op(rank<1>, const T& x, - context& ctx, + context&, const shape& output, const std::vector& inputs, const std::vector&, - F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs)) + F) -> decltype(x.compute(output, inputs)) { - return x.compute(auto_any_cast(ctx), output, inputs); + return x.compute(output, inputs); } template