Commit a1952a8d authored by Paul's avatar Paul
Browse files

Make compute optional

parent f550da30
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp> #include <migraph/auto_any_cast.hpp>
...@@ -55,11 +56,28 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -55,11 +56,28 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream } // namespace operation_stream
template <class T>
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
{
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
MIGRAPH_THROW("Not computable: " + x.name());
}
template <class T> template <class T>
argument argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input) compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return x.compute(auto_any_cast(ctx), output_shape, input); return compute_op(rank<1>{}, x, ctx, output_shape, input);
} }
/* /*
......
...@@ -41,11 +41,6 @@ struct batch_norm_inference ...@@ -41,11 +41,6 @@ struct batch_norm_inference
check_shapes{inputs, *this}.has(5); check_shapes{inputs, *this}.has(5);
return inputs.front(); return inputs.front();
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct convolution struct convolution
...@@ -115,11 +110,6 @@ struct convolution ...@@ -115,11 +110,6 @@ struct convolution
} }
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const convolution& op) friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{ {
os << op.name() << "["; os << op.name() << "[";
...@@ -169,11 +159,6 @@ struct im2col ...@@ -169,11 +159,6 @@ struct im2col
auto channels_col = kernel_height * kernel_width * input_channels; auto channels_col = kernel_height * kernel_width * input_channels;
return {input.type(), {output_height * output_width, channels_col}}; return {input.type(), {output_height * output_width, channels_col}};
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct pooling struct pooling
...@@ -211,11 +196,6 @@ struct pooling ...@@ -211,11 +196,6 @@ struct pooling
}}; }};
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const pooling& op) friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{ {
os << op.name() << "["; os << op.name() << "[";
...@@ -236,11 +216,6 @@ struct activation ...@@ -236,11 +216,6 @@ struct activation
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
return inputs.front(); return inputs.front();
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const activation& op) friend std::ostream& operator<<(std::ostream& os, const activation& op)
{ {
os << op.name() << ":" << op.mode; os << op.name() << ":" << op.mode;
...@@ -305,10 +280,6 @@ struct contiguous ...@@ -305,10 +280,6 @@ struct contiguous
} }
return {t, lens}; return {t, lens};
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct reshape struct reshape
...@@ -349,12 +320,10 @@ struct reshape ...@@ -349,12 +320,10 @@ struct reshape
MIGRAPH_THROW("Wrong number of elements for reshape"); MIGRAPH_THROW("Wrong number of elements for reshape");
return s; return s;
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
friend std::ostream& operator<<(std::ostream& os, const reshape& op) friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{ {
os << op.name() << "["; os << op.name() << "[";
...@@ -382,11 +351,6 @@ struct gemm ...@@ -382,11 +351,6 @@ struct gemm
return {t, {a.lens()[0], b.lens()[1]}}; return {t, {a.lens()[0], b.lens()[1]}};
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const gemm& op) friend std::ostream& operator<<(std::ostream& os, const gemm& op)
{ {
os << op.name() << "["; os << op.name() << "[";
...@@ -402,10 +366,6 @@ struct unary ...@@ -402,10 +366,6 @@ struct unary
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.at(0); return inputs.at(0);
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct identity : unary struct identity : unary
...@@ -553,10 +513,6 @@ struct binary ...@@ -553,10 +513,6 @@ struct binary
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0); return inputs.at(0);
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct add : binary struct add : binary
......
...@@ -3,19 +3,10 @@ ...@@ -3,19 +3,10 @@
#include <algorithm> #include <algorithm>
#include <initializer_list> #include <initializer_list>
#include <migraph/rank.hpp>
namespace migraph { namespace migraph {
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};
namespace detail { namespace detail {
template <class String, class T> template <class String, class T>
......
#ifndef MIGRAPH_GUARD_RTGLIB_RANK_HPP
#define MIGRAPH_GUARD_RTGLIB_RANK_HPP
namespace migraph {
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};
} // namespace migraph
#endif
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp> #include <migraph/auto_any_cast.hpp>
...@@ -55,11 +56,26 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -55,11 +56,26 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream } // namespace operation_stream
template <class T>
auto
compute_op(rank<1>, const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
{
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
argument
compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
MIGRAPH_THROW("Not computable: " + x.name());
}
template <class T> template <class T>
argument argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input) compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return x.compute(auto_any_cast(ctx), output_shape, input); return compute_op(rank<1>{}, x, ctx, output_shape, input);
} }
<% <%
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment