Commit d21778c6 authored by Paul's avatar Paul
Browse files

Add shape param to compute

parent 0f318650
......@@ -12,7 +12,7 @@ struct literal
{
std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); }
};
struct param
......@@ -20,7 +20,7 @@ struct param
std::string parameter;
std::string name() const { return "@param"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const param& op)
{
os << op.name() << ":" << op.parameter;
......
......@@ -28,7 +28,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* argument compute(shape output,std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
*
......@@ -95,10 +95,10 @@ struct operation
return (*this).private_detail_te_get_handle().compute_shape(std::move(input));
}
argument compute(std::vector<argument> input) const
argument compute(shape output, std::vector<argument> input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(std::move(input));
return (*this).private_detail_te_get_handle().compute(std::move(output), std::move(input));
}
friend std::ostream& operator<<(std::ostream& os, const operation& op)
......@@ -114,10 +114,10 @@ struct operation
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(shape output, std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -156,10 +156,10 @@ struct operation
return private_detail_te_value.compute_shape(std::move(input));
}
argument compute(std::vector<argument> input) const override
argument compute(shape output, std::vector<argument> input) const override
{
return private_detail_te_value.compute(std::move(input));
return private_detail_te_value.compute(std::move(output), std::move(input));
}
std::ostream& operator_shift_left(std::ostream& os) const override
......
......@@ -10,7 +10,7 @@ namespace rtg {
struct not_computable
{
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
};
struct convolution
......@@ -52,7 +52,7 @@ struct convolution
}};
}
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{
......@@ -98,7 +98,7 @@ struct pooling
}};
}
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{
......@@ -122,7 +122,7 @@ struct activation
return inputs.front();
}
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const activation& op)
{
os << op.name() << ":" << op.mode;
......@@ -153,7 +153,7 @@ struct reshape
return {inputs.front().type(), rdims};
}
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{
......
......@@ -22,7 +22,7 @@ struct unknown
else
return input.front();
}
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
rtg::argument compute(rtg::shape, std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
......
......@@ -109,7 +109,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
ins.arguments.end(),
values.begin(),
[&](instruction_ref i) { return results.at(std::addressof(*i)); });
result = ins.op.compute(values);
result = ins.op.compute(ins.result, values);
}
results.emplace(std::addressof(ins), result);
}
......
......@@ -13,10 +13,9 @@ struct cpu_convolution
std::string name() const { return "cpu::convolution"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
shape output_shape = compute_shape({args[0].get_shape(), args[1].get_shape()});
argument result{compute_shape({args[0].get_shape(), args[1].get_shape()})};
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
auto in_n = input.get_shape().lens()[0];
auto in_c = input.get_shape().lens()[1];
......@@ -53,9 +52,9 @@ struct relu
std::string name() const { return "cpu::relu"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{args[0].get_shape()};
argument result{output_shape};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [](auto x) {
......
......@@ -8,7 +8,7 @@
struct sum_op
{
std::string name() const { return "sum"; }
rtg::argument compute(std::vector<rtg::argument> args) const
rtg::argument compute(rtg::shape, std::vector<rtg::argument> args) const
{
rtg::argument result;
if(args.size() != 2)
......@@ -37,7 +37,7 @@ struct sum_op
struct minus_op
{
std::string name() const { return "minus"; }
rtg::argument compute(std::vector<rtg::argument> args) const
rtg::argument compute(rtg::shape, std::vector<rtg::argument> args) const
{
rtg::argument result;
if(args.size() != 2)
......
......@@ -9,7 +9,7 @@ struct simple_operation
int data = 1;
std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
rtg::argument compute(rtg::shape, std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{
os << "[" << op.name() << "]";
......@@ -21,7 +21,7 @@ struct simple_operation_no_print
{
std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
rtg::argument compute(rtg::shape, std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
};
void operation_copy_test()
......
......@@ -25,7 +25,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
interface('operation',
virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True),
virtual('compute', returns='argument', input='std::vector<argument>', const=True),
virtual('compute', returns='argument', output='shape', input='std::vector<argument>', const=True),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='rtg::operation_stream::operator<<')
)
%>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment