Commit ac230464 authored by Scott Thornton's avatar Scott Thornton
Browse files
parents 9a7c3e30 6f0e001e
...@@ -73,14 +73,14 @@ struct target ...@@ -73,14 +73,14 @@ struct target
std::string name() const std::string name() const
{ {
assert(private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return private_detail_te_get_handle().name(); return (*this).private_detail_te_get_handle().name();
} }
void apply(program& p) const void apply(program& p) const
{ {
assert(private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return private_detail_te_get_handle().apply(p); return (*this).private_detail_te_get_handle().apply(p);
} }
private: private:
......
...@@ -24,6 +24,11 @@ struct unknown ...@@ -24,6 +24,11 @@ struct unknown
return input.front(); return input.front();
} }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); } rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
return os;
}
}; };
template <class C, class T> template <class C, class T>
...@@ -211,6 +216,7 @@ struct onnx_parser ...@@ -211,6 +216,7 @@ struct onnx_parser
case onnx::AttributeProto::TENSORS: return {}; case onnx::AttributeProto::TENSORS: return {};
case onnx::AttributeProto::GRAPHS: return {}; case onnx::AttributeProto::GRAPHS: return {};
} }
RTG_THROW("Invalid attribute type");
} }
static rtg::literal parse_tensor(const onnx::TensorProto& t) static rtg::literal parse_tensor(const onnx::TensorProto& t)
...@@ -251,6 +257,7 @@ struct onnx_parser ...@@ -251,6 +257,7 @@ struct onnx_parser
case onnx::TensorProto::COMPLEX64: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
} }
RTG_THROW("Invalid tensor type");
} }
static rtg::shape parse_type(const onnx::TypeProto& t) static rtg::shape parse_type(const onnx::TypeProto& t)
......
...@@ -98,9 +98,9 @@ literal program::eval(std::unordered_map<std::string, argument> params) const ...@@ -98,9 +98,9 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
{ {
result = ins.lit.get_argument(); result = ins.lit.get_argument();
} }
else if(starts_with(ins.op.name(), "@param")) else if(ins.op.name() == "@param")
{ {
result = params.at(ins.op.name().substr(7)); result = params.at(any_cast<builtin::param>(ins.op).parameter);
} }
else else
{ {
...@@ -124,14 +124,14 @@ std::ostream& operator<<(std::ostream& os, const program& p) ...@@ -124,14 +124,14 @@ std::ostream& operator<<(std::ostream& os, const program& p)
for(auto& ins : p.impl->instructions) for(auto& ins : p.impl->instructions)
{ {
std::string var_name = "@" + std::to_string(count); std::string var_name = "@" + std::to_string(count);
if(starts_with(ins.op.name(), "@param")) if(ins.op.name() == "@param")
{ {
var_name = ins.op.name().substr(7); var_name = any_cast<builtin::param>(ins.op).parameter;
} }
os << var_name << " = "; os << var_name << " = ";
os << ins.op.name(); os << ins.op;
if(ins.op.name() == "@literal") if(ins.op.name() == "@literal")
{ {
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <rtg/program.hpp> #include <rtg/program.hpp>
#include <rtg/argument.hpp> #include <rtg/argument.hpp>
#include <rtg/shape.hpp> #include <rtg/shape.hpp>
#include <sstream>
#include "test.hpp" #include "test.hpp"
struct sum_op struct sum_op
...@@ -94,6 +95,20 @@ void literal_test2() ...@@ -94,6 +95,20 @@ void literal_test2()
EXPECT(result != rtg::literal{3}); EXPECT(result != rtg::literal{3});
} }
void print_test()
{
rtg::program p;
auto x = p.add_parameter("x", {rtg::shape::int64_type});
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, x, two);
std::stringstream ss;
ss << p;
std::string s = ss.str();
EXPECT(!s.empty());
}
void param_test() void param_test()
{ {
rtg::program p; rtg::program p;
...@@ -156,6 +171,7 @@ int main() ...@@ -156,6 +171,7 @@ int main()
{ {
literal_test1(); literal_test1();
literal_test2(); literal_test2();
print_test();
param_test(); param_test();
replace_test(); replace_test();
insert_replace_test(); insert_replace_test();
......
...@@ -10,6 +10,18 @@ struct simple_operation ...@@ -10,6 +10,18 @@ struct simple_operation
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); } rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); } rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{
os << "[" << op.name() << "]";
return os;
}
};
struct simple_operation_no_print
{
std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
}; };
void operation_copy_test() void operation_copy_test()
...@@ -36,8 +48,28 @@ void operation_any_cast() ...@@ -36,8 +48,28 @@ void operation_any_cast()
EXPECT(rtg::any_cast<not_operation*>(&op2) == nullptr); EXPECT(rtg::any_cast<not_operation*>(&op2) == nullptr);
} }
void operation_print()
{
rtg::operation op = simple_operation{};
std::stringstream ss;
ss << op;
std::string s = ss.str();
EXPECT(s == "[simple]");
}
void operation_default_print()
{
rtg::operation op = simple_operation_no_print{};
std::stringstream ss;
ss << op;
std::string s = ss.str();
EXPECT(s == "simple");
}
int main() int main()
{ {
operation_copy_test(); operation_copy_test();
operation_any_cast(); operation_any_cast();
operation_print();
operation_default_print();
} }
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "python $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $DIR/../include/rtg/{}" ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "python $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $DIR/../src/include/rtg/{}"
...@@ -11,11 +11,22 @@ ...@@ -11,11 +11,22 @@
namespace rtg { namespace rtg {
namespace operation_stream {
template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{
return os << x.name();
}
} // namespace operation_stream
<% <%
interface('operation', interface('operation',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True), virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True),
virtual('compute', returns='argument', input='std::vector<argument>', const=True) virtual('compute', returns='argument', input='std::vector<argument>', const=True),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='rtg::operation_stream::operator<<')
) )
%> %>
......
...@@ -161,43 +161,109 @@ inline const ValueType & any_cast(const ${struct_name} & x) ...@@ -161,43 +161,109 @@ inline const ValueType & any_cast(const ${struct_name} & x)
''') ''')
nonvirtual_member = string.Template(''' nonvirtual_member = string.Template('''
${return_type} ${name}(${params}) ${const} ${friend} ${return_type} ${name}(${params}) ${const}
{ {
assert(private_detail_te_handle_mem_var); assert(${this}.private_detail_te_handle_mem_var);
return private_detail_te_get_handle().${name}(${args}); return ${this}.private_detail_te_get_handle().${internal_name}(${member_args});
} }
''') ''')
pure_virtual_member = string.Template("virtual ${return_type} ${name}(${params}) ${const} = 0;\n") pure_virtual_member = string.Template("virtual ${return_type} ${internal_name}(${member_params}) ${member_const} = 0;\n")
virtual_member = string.Template(''' virtual_member = string.Template('''
${return_type} ${name}(${params}) ${const} override ${return_type} ${internal_name}(${member_params}) ${member_const} override
{ {
return private_detail_te_value.${name}(${args}); ${using}
return ${call};
} }
''') ''')
comment_member = string.Template('''* ${return_type} ${name}(${params}) ${const};''') comment_member = string.Template('''* ${friend} ${return_type} ${name}(${params}) ${const};''')
def convert_member(d): def trim_type_name(name):
n = name.strip()
if n.startswith('const'):
return trim_type_name(n[5:])
if n.endswith(('&', '*')):
return trim_type_name(n[0:-1])
return n
def internal_name(name):
internal_names = {
'operator<<': 'operator_shift_left',
'operator>>': 'operator_shift_right',
}
if name in internal_names:
return internal_names[name]
else:
return name
def generate_call(m, friend):
if m['name'].startswith('operator'):
op = m['name'][8:]
args = m['args']
if ',' in args:
return args.replace(',', op)
else:
return string.Template('${op}${arga}').substitute(op=op, args=args)
if friend:
return string.Template('${name}(${args})').substitute(m)
return string.Template('private_detail_te_value.${name}(${args})').substitute(m)
def convert_member(d, struct_name):
for name in d: for name in d:
member = { 'name': name, 'const': ''} member = {
'name': name,
'internal_name': internal_name(name),
'const': '',
'member_const': '',
'friend': '',
'this': '(*this)',
'using': ''
}
args = [] args = []
params = [] params = []
member_args = []
member_params = []
skip = False
friend = False
if 'friend' in d[name]:
friend = True
skip = True
for x in d[name]: for x in d[name]:
t = d[name][x] t = d[name][x]
if x == 'return': if x == 'return':
member['return_type'] = t member['return_type'] = t
elif x == 'const': elif x == 'const':
member['const'] = 'const' member['const'] = 'const'
member['member_const'] = 'const'
elif x == 'friend':
member['friend'] = 'friend'
elif x == 'using':
member['using'] = 'using {};'.format(d[name]['using'])
else: else:
use_member = not(skip and struct_name == trim_type_name(t))
arg_name = x
if not use_member:
arg_name = 'private_detail_te_value'
member['this'] = x
if 'const' in t:
member['member_const'] = 'const'
if t.endswith(('&', '*')): if t.endswith(('&', '*')):
args.append(x) if use_member: member_args.append(x)
args.append(arg_name)
else: else:
args.append('std::move({})'.format(x)) if use_member: member_args.append('std::move({})'.format(x))
args.append('std::move({})'.format(arg_name))
params.append(t+' '+x) params.append(t+' '+x)
if use_member: member_params.append(t+' '+x)
else: skip = False
member['args'] = ','.join(args) member['args'] = ','.join(args)
member['member_args'] = ','.join(member_args)
member['params'] = ','.join(params) member['params'] = ','.join(params)
member['params'] = ','.join(params)
member['member_params'] = ','.join(member_params)
member['call'] = generate_call(member, friend)
return member return member
return None return None
...@@ -208,7 +274,7 @@ def generate_form(name, members): ...@@ -208,7 +274,7 @@ def generate_form(name, members):
virtual_members = [] virtual_members = []
comment_members = [] comment_members = []
for member in members: for member in members:
m = convert_member(member) m = convert_member(member, name)
nonvirtual_members.append(nonvirtual_member.substitute(m)) nonvirtual_members.append(nonvirtual_member.substitute(m))
pure_virtual_members.append(pure_virtual_member.substitute(m)) pure_virtual_members.append(pure_virtual_member.substitute(m))
virtual_members.append(virtual_member.substitute(m)) virtual_members.append(virtual_member.substitute(m))
...@@ -226,6 +292,12 @@ def virtual(name, returns=None, **kwargs): ...@@ -226,6 +292,12 @@ def virtual(name, returns=None, **kwargs):
args['return'] = returns args['return'] = returns
return { name: args } return { name: args }
def friend(name, returns=None, **kwargs):
args = kwargs
args['return'] = returns
args['friend'] = 'friend'
return { name: args }
def interface(name, *members): def interface(name, *members):
return generate_form(name, members) return generate_form(name, members)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment