Commit a1952a8d authored by Paul's avatar Paul
Browse files

Make compute optional

parent f550da30
......@@ -8,6 +8,7 @@
#include <type_traits>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
......@@ -55,11 +56,28 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream
template <class T>
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
{
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
MIGRAPH_THROW("Not computable: " + x.name());
}
template <class T>
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{
return x.compute(auto_any_cast(ctx), output_shape, input);
return compute_op(rank<1>{}, x, ctx, output_shape, input);
}
/*
......
......@@ -41,11 +41,6 @@ struct batch_norm_inference
check_shapes{inputs, *this}.has(5);
return inputs.front();
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
};
struct convolution
......@@ -115,11 +110,6 @@ struct convolution
}
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{
os << op.name() << "[";
......@@ -169,11 +159,6 @@ struct im2col
auto channels_col = kernel_height * kernel_width * input_channels;
return {input.type(), {output_height * output_width, channels_col}};
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
};
struct pooling
......@@ -211,11 +196,6 @@ struct pooling
}};
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{
os << op.name() << "[";
......@@ -236,11 +216,6 @@ struct activation
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const activation& op)
{
os << op.name() << ":" << op.mode;
......@@ -305,10 +280,6 @@ struct contiguous
}
return {t, lens};
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
};
struct reshape
......@@ -349,12 +320,10 @@ struct reshape
MIGRAPH_THROW("Wrong number of elements for reshape");
return s;
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{
os << op.name() << "[";
......@@ -382,11 +351,6 @@ struct gemm
return {t, {a.lens()[0], b.lens()[1]}};
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const gemm& op)
{
os << op.name() << "[";
......@@ -402,10 +366,6 @@ struct unary
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
};
struct identity : unary
......@@ -553,10 +513,6 @@ struct binary
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
};
struct add : binary
......
......@@ -3,19 +3,10 @@
#include <algorithm>
#include <initializer_list>
#include <migraph/rank.hpp>
namespace migraph {
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};
namespace detail {
template <class String, class T>
......
#ifndef MIGRAPH_GUARD_RTGLIB_RANK_HPP
#define MIGRAPH_GUARD_RTGLIB_RANK_HPP
namespace migraph {
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};
} // namespace migraph
#endif
......@@ -8,6 +8,7 @@
#include <type_traits>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
......@@ -55,11 +56,26 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream
template <class T>
auto
compute_op(rank<1>, const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
{
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
argument
compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
MIGRAPH_THROW("Not computable: " + x.name());
}
template <class T>
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{
return x.compute(auto_any_cast(ctx), output_shape, input);
return compute_op(rank<1>{}, x, ctx, output_shape, input);
}
<%
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment