Commit d21778c6 authored by Paul's avatar Paul
Browse files

Add shape param to compute

parent 0f318650
...@@ -12,7 +12,7 @@ struct literal ...@@ -12,7 +12,7 @@ struct literal
{ {
std::string name() const { return "@literal"; } std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); }
}; };
struct param struct param
...@@ -20,7 +20,7 @@ struct param ...@@ -20,7 +20,7 @@ struct param
std::string parameter; std::string parameter;
std::string name() const { return "@param"; } std::string name() const { return "@param"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const param& op) friend std::ostream& operator<<(std::ostream& os, const param& op)
{ {
os << op.name() << ":" << op.parameter; os << op.name() << ":" << op.parameter;
......
...@@ -28,7 +28,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -28,7 +28,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* { * {
* std::string name() const; * std::string name() const;
* shape compute_shape(std::vector<shape> input) const; * shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const; * argument compute(shape output,std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* }; * };
* *
...@@ -95,10 +95,10 @@ struct operation ...@@ -95,10 +95,10 @@ struct operation
return (*this).private_detail_te_get_handle().compute_shape(std::move(input)); return (*this).private_detail_te_get_handle().compute_shape(std::move(input));
} }
argument compute(std::vector<argument> input) const argument compute(shape output, std::vector<argument> input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(std::move(input)); return (*this).private_detail_te_get_handle().compute(std::move(output), std::move(input));
} }
friend std::ostream& operator<<(std::ostream& os, const operation& op) friend std::ostream& operator<<(std::ostream& os, const operation& op)
...@@ -114,10 +114,10 @@ struct operation ...@@ -114,10 +114,10 @@ struct operation
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0; virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(std::vector<argument> input) const = 0; virtual argument compute(shape output, std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -156,10 +156,10 @@ struct operation ...@@ -156,10 +156,10 @@ struct operation
return private_detail_te_value.compute_shape(std::move(input)); return private_detail_te_value.compute_shape(std::move(input));
} }
argument compute(std::vector<argument> input) const override argument compute(shape output, std::vector<argument> input) const override
{ {
return private_detail_te_value.compute(std::move(input)); return private_detail_te_value.compute(std::move(output), std::move(input));
} }
std::ostream& operator_shift_left(std::ostream& os) const override std::ostream& operator_shift_left(std::ostream& os) const override
......
...@@ -10,7 +10,7 @@ namespace rtg { ...@@ -10,7 +10,7 @@ namespace rtg {
struct not_computable struct not_computable
{ {
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
}; };
struct convolution struct convolution
...@@ -52,7 +52,7 @@ struct convolution ...@@ -52,7 +52,7 @@ struct convolution
}}; }};
} }
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const convolution& op) friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{ {
...@@ -98,7 +98,7 @@ struct pooling ...@@ -98,7 +98,7 @@ struct pooling
}}; }};
} }
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const pooling& op) friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{ {
...@@ -122,7 +122,7 @@ struct activation ...@@ -122,7 +122,7 @@ struct activation
return inputs.front(); return inputs.front();
} }
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const activation& op) friend std::ostream& operator<<(std::ostream& os, const activation& op)
{ {
os << op.name() << ":" << op.mode; os << op.name() << ":" << op.mode;
...@@ -153,7 +153,7 @@ struct reshape ...@@ -153,7 +153,7 @@ struct reshape
return {inputs.front().type(), rdims}; return {inputs.front().type(), rdims};
} }
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const reshape& op) friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{ {
......
...@@ -22,7 +22,7 @@ struct unknown ...@@ -22,7 +22,7 @@ struct unknown
else else
return input.front(); return input.front();
} }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); } rtg::argument compute(rtg::shape, std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const unknown& x) friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{ {
os << x.name(); os << x.name();
......
...@@ -109,7 +109,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const ...@@ -109,7 +109,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
ins.arguments.end(), ins.arguments.end(),
values.begin(), values.begin(),
[&](instruction_ref i) { return results.at(std::addressof(*i)); }); [&](instruction_ref i) { return results.at(std::addressof(*i)); });
result = ins.op.compute(values); result = ins.op.compute(ins.result, values);
} }
results.emplace(std::addressof(ins), result); results.emplace(std::addressof(ins), result);
} }
......
...@@ -13,10 +13,9 @@ struct cpu_convolution ...@@ -13,10 +13,9 @@ struct cpu_convolution
std::string name() const { return "cpu::convolution"; } std::string name() const { return "cpu::convolution"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
shape output_shape = compute_shape({args[0].get_shape(), args[1].get_shape()}); argument result{output_shape};
argument result{compute_shape({args[0].get_shape(), args[1].get_shape()})};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) { visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
auto in_n = input.get_shape().lens()[0]; auto in_n = input.get_shape().lens()[0];
auto in_c = input.get_shape().lens()[1]; auto in_c = input.get_shape().lens()[1];
...@@ -53,9 +52,9 @@ struct relu ...@@ -53,9 +52,9 @@ struct relu
std::string name() const { return "cpu::relu"; } std::string name() const { return "cpu::relu"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); } shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
argument result{args[0].get_shape()}; argument result{output_shape};
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [](auto x) { std::transform(input.begin(), input.end(), output.begin(), [](auto x) {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
struct sum_op struct sum_op
{ {
std::string name() const { return "sum"; } std::string name() const { return "sum"; }
rtg::argument compute(std::vector<rtg::argument> args) const rtg::argument compute(rtg::shape, std::vector<rtg::argument> args) const
{ {
rtg::argument result; rtg::argument result;
if(args.size() != 2) if(args.size() != 2)
...@@ -37,7 +37,7 @@ struct sum_op ...@@ -37,7 +37,7 @@ struct sum_op
struct minus_op struct minus_op
{ {
std::string name() const { return "minus"; } std::string name() const { return "minus"; }
rtg::argument compute(std::vector<rtg::argument> args) const rtg::argument compute(rtg::shape, std::vector<rtg::argument> args) const
{ {
rtg::argument result; rtg::argument result;
if(args.size() != 2) if(args.size() != 2)
......
...@@ -9,7 +9,7 @@ struct simple_operation ...@@ -9,7 +9,7 @@ struct simple_operation
int data = 1; int data = 1;
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); } rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); } rtg::argument compute(rtg::shape, std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op) friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{ {
os << "[" << op.name() << "]"; os << "[" << op.name() << "]";
...@@ -21,7 +21,7 @@ struct simple_operation_no_print ...@@ -21,7 +21,7 @@ struct simple_operation_no_print
{ {
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); } rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); } rtg::argument compute(rtg::shape, std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
}; };
void operation_copy_test() void operation_copy_test()
......
...@@ -25,7 +25,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -25,7 +25,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
interface('operation', interface('operation',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True), virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True),
virtual('compute', returns='argument', input='std::vector<argument>', const=True), virtual('compute', returns='argument', output='shape', input='std::vector<argument>', const=True),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='rtg::operation_stream::operator<<') friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='rtg::operation_stream::operator<<')
) )
%> %>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment