Commit cdea8d86 authored by Paul's avatar Paul
Browse files

Add stream operator

parent ab0ea297
...@@ -23,6 +23,11 @@ struct unknown ...@@ -23,6 +23,11 @@ struct unknown
return input.front(); return input.front();
} }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); } rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
friend std::ostream & operator<<(std::ostream & os, const unknown & x)
{
os << x.name();
return os;
}
}; };
template <class C, class T> template <class C, class T>
......
...@@ -13,6 +13,11 @@ struct literal ...@@ -13,6 +13,11 @@ struct literal
std::string name() const { return "@literal"; } std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream & operator<<(std::ostream & os, const literal & op)
{
os << op.name();
return os;
}
}; };
struct param struct param
...@@ -21,6 +26,11 @@ struct param ...@@ -21,6 +26,11 @@ struct param
std::string name() const { return "@param:" + parameter; } std::string name() const { return "@param:" + parameter; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream & operator<<(std::ostream & os, const param & op)
{
os << op.name();
return os;
}
}; };
} // namespace builtin } // namespace builtin
......
...@@ -19,6 +19,7 @@ namespace rtg { ...@@ -19,6 +19,7 @@ namespace rtg {
* std::string name() const; * std::string name() const;
* shape compute_shape(std::vector<shape> input) const; * shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const; * argument compute(std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* }; * };
* *
*/ */
...@@ -74,20 +75,26 @@ struct operation ...@@ -74,20 +75,26 @@ struct operation
std::string name() const std::string name() const
{ {
assert(private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return private_detail_te_get_handle().name(); return (*this).private_detail_te_get_handle().name();
} }
shape compute_shape(std::vector<shape> input) const shape compute_shape(std::vector<shape> input) const
{ {
assert(private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return private_detail_te_get_handle().compute_shape(std::move(input)); return (*this).private_detail_te_get_handle().compute_shape(std::move(input));
} }
argument compute(std::vector<argument> input) const argument compute(std::vector<argument> input) const
{ {
assert(private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return private_detail_te_get_handle().compute(std::move(input)); return (*this).private_detail_te_get_handle().compute(std::move(input));
}
friend std::ostream& operator<<(std::ostream& os, const operation& op)
{
assert(op.private_detail_te_handle_mem_var);
return op.private_detail_te_get_handle().operator_shift_left(os);
} }
private: private:
...@@ -100,6 +107,7 @@ struct operation ...@@ -100,6 +107,7 @@ struct operation
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0; virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(std::vector<argument> input) const = 0; virtual argument compute(std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -142,6 +150,11 @@ struct operation ...@@ -142,6 +150,11 @@ struct operation
return private_detail_te_value.compute(std::move(input)); return private_detail_te_value.compute(std::move(input));
} }
std::ostream& operator_shift_left(std::ostream& os) const override
{
return os << private_detail_te_value;
}
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
......
...@@ -56,6 +56,12 @@ struct convolution ...@@ -56,6 +56,12 @@ struct convolution
} }
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream & operator<<(std::ostream & os, const convolution & op)
{
os << op.name();
return os;
}
}; };
struct pooling struct pooling
...@@ -96,6 +102,12 @@ struct pooling ...@@ -96,6 +102,12 @@ struct pooling
} }
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream & operator<<(std::ostream & os, const pooling & op)
{
os << op.name();
return os;
}
}; };
struct activation struct activation
...@@ -110,6 +122,11 @@ struct activation ...@@ -110,6 +122,11 @@ struct activation
} }
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream & operator<<(std::ostream & os, const activation & op)
{
os << op.name();
return os;
}
}; };
struct reshape struct reshape
...@@ -136,6 +153,12 @@ struct reshape ...@@ -136,6 +153,12 @@ struct reshape
} }
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream & operator<<(std::ostream & os, const reshape & op)
{
os << op.name();
return os;
}
}; };
} // namespace rtg } // namespace rtg
......
...@@ -73,14 +73,14 @@ struct target ...@@ -73,14 +73,14 @@ struct target
std::string name() const std::string name() const
{ {
assert(private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return private_detail_te_get_handle().name(); return (*this).private_detail_te_get_handle().name();
} }
void apply(program& p) const void apply(program& p) const
{ {
assert(private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return private_detail_te_get_handle().apply(p); return (*this).private_detail_te_get_handle().apply(p);
} }
private: private:
......
...@@ -31,6 +31,12 @@ struct sum_op ...@@ -31,6 +31,12 @@ struct sum_op
RTG_THROW("Wrong inputs"); RTG_THROW("Wrong inputs");
return inputs.front(); return inputs.front();
} }
friend std::ostream & operator<<(std::ostream & os, const sum_op & op)
{
os << op.name();
return os;
}
}; };
struct minus_op struct minus_op
...@@ -60,6 +66,12 @@ struct minus_op ...@@ -60,6 +66,12 @@ struct minus_op
RTG_THROW("Wrong inputs"); RTG_THROW("Wrong inputs");
return inputs.front(); return inputs.front();
} }
friend std::ostream & operator<<(std::ostream & os, const minus_op & op)
{
os << op.name();
return os;
}
}; };
struct id_target struct id_target
......
...@@ -10,6 +10,11 @@ struct simple_operation ...@@ -10,6 +10,11 @@ struct simple_operation
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); } rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); } rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
friend std::ostream & operator<<(std::ostream & os, const simple_operation & op)
{
os << op.name();
return os;
}
}; };
void operation_copy_test() void operation_copy_test()
......
...@@ -15,7 +15,8 @@ namespace rtg { ...@@ -15,7 +15,8 @@ namespace rtg {
interface('operation', interface('operation',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True), virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True),
virtual('compute', returns='argument', input='std::vector<argument>', const=True) virtual('compute', returns='argument', input='std::vector<argument>', const=True),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &')
) )
%> %>
......
...@@ -163,15 +163,15 @@ inline const ValueType & any_cast(const ${struct_name} & x) ...@@ -163,15 +163,15 @@ inline const ValueType & any_cast(const ${struct_name} & x)
nonvirtual_member = string.Template(''' nonvirtual_member = string.Template('''
${friend} ${return_type} ${name}(${params}) ${const} ${friend} ${return_type} ${name}(${params}) ${const}
{ {
assert(private_detail_te_handle_mem_var); assert(${this}.private_detail_te_handle_mem_var);
return private_detail_te_get_handle().${internal_name}(${member_args}); return ${this}.private_detail_te_get_handle().${internal_name}(${member_args});
} }
''') ''')
pure_virtual_member = string.Template("virtual ${return_type} ${internal_name}(${member_params}) ${const} = 0;\n") pure_virtual_member = string.Template("virtual ${return_type} ${internal_name}(${member_params}) ${member_const} = 0;\n")
virtual_member = string.Template(''' virtual_member = string.Template('''
${return_type} ${internal_name}(${member_params}) ${const} override ${return_type} ${internal_name}(${member_params}) ${member_const} override
{ {
return ${call}; return ${call};
} }
...@@ -201,17 +201,24 @@ def generate_call(m, friend): ...@@ -201,17 +201,24 @@ def generate_call(m, friend):
if m['name'].startswith('operator'): if m['name'].startswith('operator'):
op = m['name'][8:] op = m['name'][8:]
args = m['args'] args = m['args']
if len(m[args]) == 2: if ',' in args:
return string.Template('${arg1} ${op} ${arg2}').substitute(op=op, arg1=args[0], arg2=args[1]) return args.replace(',', op)
else: else:
return string.Template('${op}${arg1}').substitute(op=op, arg1=args[0]) return string.Template('${op}${arga}').substitute(op=op, args=args)
if friend: if friend:
return string.Template('${name}(${args})').substitute(m) return string.Template('${name}(${args})').substitute(m)
return string.Template('private_detail_te_value.${name}(${args})').substitute(m) return string.Template('private_detail_te_value.${name}(${args})').substitute(m)
def convert_member(d, struct_name): def convert_member(d, struct_name):
for name in d: for name in d:
member = { 'name': name, 'internal_name': internal_name(name), 'const': '', 'friend': ''} member = {
'name': name,
'internal_name': internal_name(name),
'const': '',
'member_const': '',
'friend': '',
'this': '(*this)'
}
args = [] args = []
params = [] params = []
member_args = [] member_args = []
...@@ -227,12 +234,17 @@ def convert_member(d, struct_name): ...@@ -227,12 +234,17 @@ def convert_member(d, struct_name):
member['return_type'] = t member['return_type'] = t
elif x == 'const': elif x == 'const':
member['const'] = 'const' member['const'] = 'const'
member['member_const'] = 'const'
elif x == 'friend': elif x == 'friend':
member['friend'] = 'friend' member['friend'] = 'friend'
else: else:
use_member = not(skip and struct_name == trim_type_name(t)) use_member = not(skip and struct_name == trim_type_name(t))
arg_name = x arg_name = x
if not use_member: arg_name = 'private_detail_te_value' if not use_member:
arg_name = 'private_detail_te_value'
member['this'] = x
if 'const' in t:
member['member_const'] = 'const'
if t.endswith(('&', '*')): if t.endswith(('&', '*')):
if use_member: member_args.append(x) if use_member: member_args.append(x)
args.append(arg_name) args.append(arg_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment