Commit 1c8a45ee authored by Paul's avatar Paul
Browse files

Add eval function to instruction

parent 46b3e7da
......@@ -71,6 +71,8 @@ struct instruction
static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
argument eval() const;
static instruction_ref get_output_alias(instruction_ref ins);
private:
......
......@@ -53,6 +53,9 @@ struct operation
friend std::ostream& operator<<(std::ostream& os, const operation& op);
};
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
#else
namespace operation_stream {
......@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_equal
template <class T>
auto compute_op(rank<1>,
auto compute_op(rank<2>,
const T& x,
context& ctx,
const shape& output_shape,
......@@ -99,6 +102,14 @@ auto compute_op(rank<1>,
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
......@@ -110,7 +121,53 @@ template <class T>
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<1>{}, x, ctx, output_shape, input);
return compute_op(rank<2>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, output_shape, input);
}
template <class T>
auto is_context_free_op(rank<1>,
const T& x,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input), std::true_type{});
template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
-> std::false_type;
template <class T>
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
rank<1>{}, x, std::declval<const shape&>(), std::declval<std::vector<argument>>()))
{
return {};
}
template <class T>
......@@ -138,9 +195,11 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
* struct operation
* {
* std::string name() const;
* bool is_context_free() const;
* int output_alias(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* };
......@@ -210,6 +269,12 @@ struct operation
return (*this).private_detail_te_get_handle().name();
}
bool is_context_free() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_context_free();
}
int output_alias(const std::vector<shape>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -228,6 +293,12 @@ struct operation
return (*this).private_detail_te_get_handle().compute(ctx, output, input);
}
argument compute(const shape& output, const std::vector<argument>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(output, input);
}
friend std::ostream& operator<<(std::ostream& os, const operation& op)
{
assert(op.private_detail_te_handle_mem_var);
......@@ -248,10 +319,12 @@ struct operation
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual bool is_context_free() const = 0;
virtual int output_alias(const std::vector<shape>& input) const = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0;
virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
};
......@@ -286,6 +359,12 @@ struct operation
std::string name() const override { return private_detail_te_value.name(); }
bool is_context_free() const override
{
return is_context_free_op(private_detail_te_value);
}
int output_alias(const std::vector<shape>& input) const override
{
......@@ -306,6 +385,12 @@ struct operation
return compute_op(private_detail_te_value, ctx, output, input);
}
argument compute(const shape& output, const std::vector<argument>& input) const override
{
return compute_op(private_detail_te_value, output, input);
}
std::ostream& operator_shift_left(std::ostream& os) const override
{
using migraphx::operation_stream::operator<<;
......@@ -385,6 +470,14 @@ inline const ValueType& any_cast(const operation& x)
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
template <class T>
bool is_context_free(const T& x)
{
return is_context_free_op(x);
}
#endif
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ namespace op {
struct not_computable
{
argument compute(context&, const shape&, const std::vector<argument>&) const
argument compute(const shape&, const std::vector<argument>&) const
{
MIGRAPHX_THROW("not computable");
}
......@@ -296,7 +296,7 @@ struct transpose
}
return {t, output_lens, output_strides};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
......@@ -437,7 +437,7 @@ struct slice
}
return shape{t, new_lens, old_strides};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
auto input = args[0];
auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
......@@ -487,7 +487,7 @@ struct squeeze
}
return shape{type, new_lens};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
......@@ -526,7 +526,7 @@ struct unsqueeze
}
return shape{type, new_lens};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
......@@ -578,7 +578,7 @@ struct reshape
MIGRAPHX_THROW("Wrong number of elements for reshape");
return s;
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
......@@ -624,7 +624,7 @@ struct identity
{
std::string name() const { return "identity"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
......@@ -742,7 +742,7 @@ struct flatten
std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {inputs.at(0).type(), {x, y}};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
......@@ -794,7 +794,7 @@ struct broadcast
return {t, broadcast_shape.lens(), std::move(bcast_strides)};
}
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
......@@ -836,7 +836,7 @@ struct multibroadcast
}
return {t, output_lens, bcast_strides};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
......@@ -858,7 +858,7 @@ struct scalar
return {t, scalar_bcast.lens(), strides};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
......@@ -923,7 +923,7 @@ struct load
check_shapes{inputs}.has(1);
return s;
}
argument compute(context&, const shape&, const std::vector<argument>& args) const
argument compute(const shape&, const std::vector<argument>& args) const
{
return {s, args[0].data() + offset};
}
......@@ -946,7 +946,7 @@ struct outline
check_shapes{inputs, *this}.has(0);
return s;
}
argument compute(context&, const shape&, const std::vector<argument>&) const
argument compute(const shape&, const std::vector<argument>&) const
{
return {s, nullptr};
}
......
......@@ -170,6 +170,26 @@ std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
return shapes;
}
argument instruction::eval() const
{
if(op.name() == "@literal")
{
return this->get_literal().get_argument();
}
if(is_context_free(op))
{
std::vector<argument> args;
for(auto&& arg : this->inputs())
{
argument a = arg->eval();
if(a.empty()) return {};
args.push_back(a);
}
return op.compute(result, args);
}
return {};
}
instruction_ref instruction::get_output_alias(instruction_ref ins)
{
auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs()));
......
......@@ -53,6 +53,9 @@ struct operation
friend std::ostream& operator<<(std::ostream& os, const operation& op);
};
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
#else
namespace operation_stream {
......@@ -89,7 +92,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_equal
template <class T>
auto compute_op(rank<1>,
auto compute_op(rank<2>,
const T& x,
context& ctx,
const shape& output_shape,
......@@ -99,6 +102,17 @@ auto compute_op(rank<1>,
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
auto compute_op(rank<1>,
const T& x,
context&,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
......@@ -110,7 +124,54 @@ template <class T>
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<1>{}, x, ctx, output_shape, input);
return compute_op(rank<2>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>,
const T& x,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>,
const T& x,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument
compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, output_shape, input);
}
template <class T>
auto is_context_free_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input) -> decltype(x.compute(output_shape, input), std::true_type{});
template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&) -> std::false_type;
template <class T>
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(rank<1>{}, x, std::declval<const shape&>(), std::declval<std::vector<argument>>()))
{
return {};
}
template <class T>
......@@ -136,6 +197,7 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
interface(
'operation',
virtual('name', returns = 'std::string', const = True),
virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'),
virtual('output_alias',
returns = 'int',
input = 'const std::vector<shape>&',
......@@ -149,6 +211,12 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
input = 'const std::vector<argument>&',
const = True,
default = 'compute_op'),
virtual('compute',
returns = 'argument',
output = 'const shape&',
input = 'const std::vector<argument>&',
const = True,
default = 'compute_op'),
friend('operator<<',
returns = 'std::ostream &',
os = 'std::ostream &',
......@@ -165,6 +233,17 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return !(x == y);
}
inline bool is_context_free(const operation& op)
{
return op.is_context_free();
}
template<class T>
bool is_context_free(const T& x)
{
return is_context_free_op(x);
}
#endif
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment