"driver/conv_driver.cpp" did not exist on "d075adf12642815a0755823b4d268766a6c2346c"
Commit f320a3da authored by Paul's avatar Paul
Browse files

Auto cast context

parent 29448044
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
namespace migraph { namespace migraph {
...@@ -22,6 +23,12 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -22,6 +23,12 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream } // namespace operation_stream
template <class T>
argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<argument> input)
{
return x.compute(auto_any_cast(ctx), output_shape, input);
}
/* /*
* Type-erased interface for: * Type-erased interface for:
* *
...@@ -169,7 +176,7 @@ struct operation ...@@ -169,7 +176,7 @@ struct operation
argument compute(context& ctx, shape output, std::vector<argument> input) const override argument compute(context& ctx, shape output, std::vector<argument> input) const override
{ {
return private_detail_te_value.compute(ctx, std::move(output), std::move(input)); return compute_op(private_detail_te_value, ctx, std::move(output), std::move(input));
} }
std::ostream& operator_shift_left(std::ostream& os) const override std::ostream& operator_shift_left(std::ostream& os) const override
......
...@@ -25,9 +25,8 @@ struct miopen_convolution ...@@ -25,9 +25,8 @@ struct miopen_convolution
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto w_desc = make_tensor(args[1].get_shape()); auto w_desc = make_tensor(args[1].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
...@@ -77,9 +76,8 @@ struct miopen_pooling ...@@ -77,9 +76,8 @@ struct miopen_pooling
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(1)}); return op.compute_shape({inputs.at(1)});
} }
argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
...@@ -110,7 +108,7 @@ struct miopen_add ...@@ -110,7 +108,7 @@ struct miopen_add
return inputs.at(0); return inputs.at(0);
} }
argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const
{ {
if(args[1].get_shape().broadcasted()) if(args[1].get_shape().broadcasted())
{ {
...@@ -127,7 +125,6 @@ struct miopen_add ...@@ -127,7 +125,6 @@ struct miopen_add
} }
else else
{ {
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto a_desc = make_tensor(args[0].get_shape()); auto a_desc = make_tensor(args[0].get_shape());
auto b_desc = make_tensor(args[1].get_shape()); auto b_desc = make_tensor(args[1].get_shape());
...@@ -157,9 +154,8 @@ struct miopen_gemm ...@@ -157,9 +154,8 @@ struct miopen_gemm
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
rocblas_int lda = args[0].get_shape().lens()[1]; rocblas_int lda = args[0].get_shape().lens()[1];
...@@ -196,9 +192,8 @@ struct miopen_relu ...@@ -196,9 +192,8 @@ struct miopen_relu
return inputs.at(1); return inputs.at(1);
} }
argument compute(migraph::context& gctx, shape output_shape, std::vector<argument> args) const argument compute(miopen_context& ctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
namespace migraph { namespace migraph {
...@@ -22,11 +23,17 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -22,11 +23,17 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream } // namespace operation_stream
template<class T>
argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<argument> input)
{
return x.compute(auto_any_cast(ctx), output_shape, input);
}
<% <%
interface('operation', interface('operation',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True), virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True),
virtual('compute', returns='argument', ctx='context&', output='shape', input='std::vector<argument>', const=True), virtual('compute', returns='argument', ctx='context&', output='shape', input='std::vector<argument>', const=True, default='compute_op'),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='migraph::operation_stream::operator<<') friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='migraph::operation_stream::operator<<')
) )
%> %>
......
...@@ -213,16 +213,21 @@ def internal_name(name): ...@@ -213,16 +213,21 @@ def internal_name(name):
else: else:
return name return name
def generate_call(m, friend): def generate_call(m, friend, indirect):
if m['name'].startswith('operator'): if m['name'].startswith('operator'):
op = m['name'][8:] op = m['name'][8:]
args = m['args'] args = m['args']
if ',' in args: if ',' in args:
return args.replace(',', op) return args.replace(',', op)
else: else:
return string.Template('${op}${arga}').substitute(op=op, args=args) return string.Template('${op}${args}').substitute(op=op, args=args)
if friend: if friend:
return string.Template('${name}(${args})').substitute(m) return string.Template('${name}(${args})').substitute(m)
if indirect:
if m['args']:
return string.Template('${default}(private_detail_te_value, ${args})').substitute(m)
else:
return string.Template('${default}(private_detail_te_value)').substitute(m)
return string.Template('private_detail_te_value.${name}(${args})').substitute(m) return string.Template('private_detail_te_value.${name}(${args})').substitute(m)
def convert_member(d, struct_name): def convert_member(d, struct_name):
...@@ -242,9 +247,12 @@ def convert_member(d, struct_name): ...@@ -242,9 +247,12 @@ def convert_member(d, struct_name):
member_params = [] member_params = []
skip = False skip = False
friend = False friend = False
indirect = False
if 'friend' in d[name]: if 'friend' in d[name]:
friend = True friend = True
skip = True skip = True
if 'default' in d[name]:
indirect = True
for x in d[name]: for x in d[name]:
t = d[name][x] t = d[name][x]
if x == 'return': if x == 'return':
...@@ -254,8 +262,12 @@ def convert_member(d, struct_name): ...@@ -254,8 +262,12 @@ def convert_member(d, struct_name):
member['member_const'] = 'const' member['member_const'] = 'const'
elif x == 'friend': elif x == 'friend':
member['friend'] = 'friend' member['friend'] = 'friend'
elif x == 'default':
member['default'] = t
elif x == 'using': elif x == 'using':
member['using'] = 'using {};'.format(d[name]['using']) member['using'] = 'using {};'.format(d[name]['using'])
elif x.startswith('__') and x.endswith('__'):
continue
else: else:
use_member = not(skip and struct_name == trim_type_name(t)) use_member = not(skip and struct_name == trim_type_name(t))
arg_name = x arg_name = x
...@@ -278,7 +290,7 @@ def convert_member(d, struct_name): ...@@ -278,7 +290,7 @@ def convert_member(d, struct_name):
member['params'] = ','.join(params) member['params'] = ','.join(params)
member['params'] = ','.join(params) member['params'] = ','.join(params)
member['member_params'] = ','.join(member_params) member['member_params'] = ','.join(member_params)
member['call'] = generate_call(member, friend) member['call'] = generate_call(member, friend, indirect)
return member return member
return None return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment