Commit 0ac28b96 authored by Paul's avatar Paul
Browse files

Add finalization phase

parent 301b7605
...@@ -15,11 +15,19 @@ struct check_context ...@@ -15,11 +15,19 @@ struct check_context
std::string name() const { return "check_context"; } std::string name() const { return "check_context"; }
shape compute_shape(const std::vector<shape>&) const { return {}; } shape compute_shape(const std::vector<shape>&) const { return {}; }
argument compute(context& ctx, const shape&, const std::vector<argument>&) const argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{
this->check(ctx);
return {};
}
void finalize(context& ctx, const shape&, const std::vector<shape>&) const
{
this->check(ctx);
}
void check(context& ctx) const
{ {
T* x = any_cast<T>(&ctx); T* x = any_cast<T>(&ctx);
if(x == nullptr) if(x == nullptr)
MIGRAPHX_THROW(std::string("Unexpected context type: ") + ctx.type_id().name()); MIGRAPHX_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
return {};
} }
}; };
......
...@@ -95,7 +95,7 @@ struct context ...@@ -95,7 +95,7 @@ struct context
void finish() const void finish() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().finish(); (*this).private_detail_te_get_handle().finish();
} }
private: private:
...@@ -136,7 +136,7 @@ struct context ...@@ -136,7 +136,7 @@ struct context
const std::type_info& type() const override { return typeid(private_detail_te_value); } const std::type_info& type() const override { return typeid(private_detail_te_value); }
void finish() const override { return private_detail_te_value.finish(); } void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
......
...@@ -14,6 +14,7 @@ namespace migraphx { ...@@ -14,6 +14,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args); shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);
struct instruction struct instruction
{ {
...@@ -73,6 +74,8 @@ struct instruction ...@@ -73,6 +74,8 @@ struct instruction
argument eval() const; argument eval() const;
void finalize(context& ctx);
static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false); static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false);
private: private:
......
...@@ -26,6 +26,8 @@ struct operation ...@@ -26,6 +26,8 @@ struct operation
{ {
/// A unique name identifying the operation /// A unique name identifying the operation
std::string name() const; std::string name() const;
/// An optional method that can be used to finalize the operator before running
void finalize(context& ctx);
/// This is used to compute the resulting shape from an operation. If an /// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an /// operation cannot be run with input shapes, then it should throw an
/// exception. /// exception.
...@@ -55,6 +57,8 @@ struct operation ...@@ -55,6 +57,8 @@ struct operation
/// Returns true if operation does not require a context to run compute /// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x); bool is_context_free(const operation& x);
/// Returns true if the operation has a finalize method
bool has_finalize(const operation& x);
#else #else
...@@ -189,6 +193,44 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -189,6 +193,44 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return output_alias_op(rank<1>{}, x, shapes); return output_alias_op(rank<1>{}, x, shapes);
} }
template <class T>
auto finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
{
x.finalize(auto_any_cast(ctx), output_shape, input);
}
template <class T>
void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
}
template <class T>
void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
finalize_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
auto has_finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});
template <class T>
auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
-> std::false_type;
template <class T>
auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
std::declval<T&>(),
std::declval<context&>(),
std::declval<const shape&>(),
std::declval<std::vector<shape>>()))
{
return {};
}
/* /*
* Type-erased interface for: * Type-erased interface for:
* *
...@@ -196,7 +238,9 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -196,7 +238,9 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
* { * {
* std::string name() const; * std::string name() const;
* bool is_context_free() const; * bool is_context_free() const;
* bool has_finalize() const;
* int output_alias(const std::vector<shape>& input) const; * int output_alias(const std::vector<shape>& input) const;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const; * shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const; * argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* argument compute(const shape& output,const std::vector<argument>& input) const; * argument compute(const shape& output,const std::vector<argument>& input) const;
...@@ -275,12 +319,24 @@ struct operation ...@@ -275,12 +319,24 @@ struct operation
return (*this).private_detail_te_get_handle().is_context_free(); return (*this).private_detail_te_get_handle().is_context_free();
} }
bool has_finalize() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().has_finalize();
}
int output_alias(const std::vector<shape>& input) const int output_alias(const std::vector<shape>& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().output_alias(input); return (*this).private_detail_te_get_handle().output_alias(input);
} }
void finalize(context& ctx, const shape& output, const std::vector<shape>& input)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().finalize(ctx, output, input);
}
shape compute_shape(const std::vector<shape>& input) const shape compute_shape(const std::vector<shape>& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -318,10 +374,13 @@ struct operation ...@@ -318,10 +374,13 @@ struct operation
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual bool is_context_free() const = 0; virtual bool is_context_free() const = 0;
virtual int output_alias(const std::vector<shape>& input) const = 0; virtual bool has_finalize() const = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0; virtual int output_alias(const std::vector<shape>& input) const = 0;
virtual void
finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0;
virtual argument virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0; compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0; virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
...@@ -365,12 +424,20 @@ struct operation ...@@ -365,12 +424,20 @@ struct operation
return is_context_free_op(private_detail_te_value); return is_context_free_op(private_detail_te_value);
} }
bool has_finalize() const override { return has_finalize_op(private_detail_te_value); }
int output_alias(const std::vector<shape>& input) const override int output_alias(const std::vector<shape>& input) const override
{ {
return output_alias_op(private_detail_te_value, input); return output_alias_op(private_detail_te_value, input);
} }
void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
{
finalize_op(private_detail_te_value, ctx, output, input);
}
shape compute_shape(const std::vector<shape>& input) const override shape compute_shape(const std::vector<shape>& input) const override
{ {
...@@ -478,6 +545,14 @@ bool is_context_free(const T& x) ...@@ -478,6 +545,14 @@ bool is_context_free(const T& x)
return is_context_free_op(x); return is_context_free_op(x);
} }
inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template <class T>
bool has_finalize(const T& x)
{
return has_finalize_op(x);
}
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -105,7 +105,7 @@ struct pass ...@@ -105,7 +105,7 @@ struct pass
void apply(program& p) const void apply(program& p) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().apply(p); (*this).private_detail_te_get_handle().apply(p);
} }
private: private:
...@@ -149,7 +149,7 @@ struct pass ...@@ -149,7 +149,7 @@ struct pass
std::string name() const override { return private_detail_te_value.name(); } std::string name() const override { return private_detail_te_value.name(); }
void apply(program& p) const override { return private_detail_te_value.apply(p); } void apply(program& p) const override { private_detail_te_value.apply(p); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
......
...@@ -95,6 +95,8 @@ struct program ...@@ -95,6 +95,8 @@ struct program
void compile(const target& t, tracer trace = tracer{}); void compile(const target& t, tracer trace = tracer{});
void finalize();
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const; void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
void debug_print() const; void debug_print() const;
......
...@@ -162,14 +162,6 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins) ...@@ -162,14 +162,6 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old->remove_output(*this); old->remove_output(*this);
} }
std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return shapes;
}
argument instruction::eval() const argument instruction::eval() const
{ {
if(op.name() == "@literal") if(op.name() == "@literal")
...@@ -191,9 +183,15 @@ argument instruction::eval() const ...@@ -191,9 +183,15 @@ argument instruction::eval() const
return {}; return {};
} }
void instruction::finalize(context& ctx)
{
if (has_finalize(this->op))
this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
}
instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow) instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
{ {
auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs())); auto i = ins->get_operator().output_alias(to_shapes(ins->inputs()));
if(i < 0) if(i < 0)
return ins; return ins;
if(shallow) if(shallow)
...@@ -201,9 +199,17 @@ instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow) ...@@ -201,9 +199,17 @@ instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
return get_output_alias(ins->inputs().at(i)); return get_output_alias(ins->inputs().at(i));
} }
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return shapes;
}
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args) shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{ {
return op.compute_shape(compute_shapes(args)); return op.compute_shape(to_shapes(args));
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -309,6 +309,15 @@ void program::compile(const target& t, tracer trace) ...@@ -309,6 +309,15 @@ void program::compile(const target& t, tracer trace)
auto index = std::distance(impl->instructions.begin(), invalid); auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPHX_THROW("Invalid program from compilation at instruction " + std::to_string(index)); MIGRAPHX_THROW("Invalid program from compilation at instruction " + std::to_string(index));
} }
this->finalize();
}
void program::finalize()
{
for(auto ins : iterator_for(*this))
{
ins->finalize(this->impl->ctx);
}
} }
template <class F> template <class F>
......
...@@ -41,11 +41,11 @@ argument miopen_convolution::compute(context& ctx, ...@@ -41,11 +41,11 @@ argument miopen_convolution::compute(context& ctx,
shape miopen_convolution::compile(context& ctx, shape miopen_convolution::compile(context& ctx,
const shape& output_shape, const shape& output_shape,
std::vector<instruction_ref> inputs) std::vector<shape> inputs)
{ {
shape workspace_shape{}; shape workspace_shape{};
auto x_desc = make_tensor(inputs[0]->get_shape()); auto x_desc = make_tensor(inputs[0]);
auto w_desc = make_tensor(inputs[1]->get_shape()); auto w_desc = make_tensor(inputs[1]);
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
std::size_t workspace_size = 0; std::size_t workspace_size = 0;
...@@ -57,8 +57,8 @@ shape miopen_convolution::compile(context& ctx, ...@@ -57,8 +57,8 @@ shape miopen_convolution::compile(context& ctx,
&workspace_size); &workspace_size);
workspace_shape = shape{shape::int8_type, {workspace_size}}; workspace_shape = shape{shape::int8_type, {workspace_size}};
auto x = to_gpu(generate_argument(inputs[0]->get_shape())); auto x = to_gpu(generate_argument(inputs[0]));
auto w = to_gpu(generate_argument(inputs[1]->get_shape())); auto w = to_gpu(generate_argument(inputs[1]));
auto y = allocate_gpu(output_shape); auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape); auto workspace = allocate_gpu(workspace_shape);
...@@ -80,10 +80,20 @@ shape miopen_convolution::compile(context& ctx, ...@@ -80,10 +80,20 @@ shape miopen_convolution::compile(context& ctx,
false); false);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("Find convolution failed"); MIGRAPHX_THROW("Find convolution failed");
handle = ctx.get_stream().get_miopen();
algo = perf.fwd_algo; algo = perf.fwd_algo;
return shape{shape::int8_type, {perf.memory}}; return shape{shape::int8_type, {perf.memory}};
} }
void miopen_convolution::finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs)
{
if (handle == ctx.get_stream().get_miopen())
return;
// TODO: Check that workspace hasn't changed
compile(ctx, output_shape, inputs);
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -274,9 +274,12 @@ struct miopen_conv_bias ...@@ -274,9 +274,12 @@ struct miopen_conv_bias
return f.execute(ctx, fargs, args[0], args[4]); return f.execute(ctx, fargs, args[0], args[4]);
} }
shape compile(context& ctx) void finalize(context& ctx, const shape&, const std::vector<shape>&)
{ {
f.compile(ctx); f.compile(ctx);
}
shape get_workspace(context& ctx)
{
return f.get_workspace(ctx); return f.get_workspace(ctx);
} }
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; } int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
...@@ -318,10 +321,12 @@ struct miopen_conv_bias_relu ...@@ -318,10 +321,12 @@ struct miopen_conv_bias_relu
miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0); miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0);
return f.execute(ctx, fargs, args[0], args[4]); return f.execute(ctx, fargs, args[0], args[4]);
} }
void finalize(context& ctx, const shape&, const std::vector<shape>&)
shape compile(context& ctx)
{ {
f.compile(ctx); f.compile(ctx);
}
shape get_workspace(context& ctx)
{
return f.get_workspace(ctx); return f.get_workspace(ctx);
} }
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; } int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
...@@ -350,7 +355,7 @@ void apply_conv_bias(context& ctx, program& p, match::matcher_result r) ...@@ -350,7 +355,7 @@ void apply_conv_bias(context& ctx, program& p, match::matcher_result r)
Op cb{conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()}; Op cb{conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
// TODO: Insert ws allocation // TODO: Insert ws allocation
auto ws = cb.compile(ctx); auto ws = cb.get_workspace(ctx);
p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins); p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
} }
......
...@@ -27,6 +27,7 @@ struct miopen_convolution ...@@ -27,6 +27,7 @@ struct miopen_convolution
op::convolution op; op::convolution op;
shared<convolution_descriptor> cd; shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{}; miopenConvFwdAlgorithm_t algo{};
miopenHandle_t handle = nullptr;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -39,7 +40,8 @@ struct miopen_convolution ...@@ -39,7 +40,8 @@ struct miopen_convolution
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
shape compile(context& ctx, const shape& output_shape, std::vector<instruction_ref> inputs); shape compile(context& ctx, const shape& output_shape, std::vector<shape> inputs);
void finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs);
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; } int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
}; };
......
...@@ -129,7 +129,7 @@ struct miopen_apply ...@@ -129,7 +129,7 @@ struct miopen_apply
auto&& op = any_cast<op::convolution>(ins->get_operator()); auto&& op = any_cast<op::convolution>(ins->get_operator());
auto conv = miopen_convolution{op, make_conv(op)}; auto conv = miopen_convolution{op, make_conv(op)};
auto ws = conv.compile(ctx, ins->get_shape(), ins->inputs()); auto ws = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
auto workspace = insert_allocation(ins, ws, "workspace"); auto workspace = insert_allocation(ins, ws, "workspace");
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
......
...@@ -26,6 +26,8 @@ struct operation ...@@ -26,6 +26,8 @@ struct operation
{ {
/// A unique name identifying the operation /// A unique name identifying the operation
std::string name() const; std::string name() const;
/// An optional method that can be used to finalize the operator before running
void finalize(context& ctx);
/// This is used to compute the resulting shape from an operation. If an /// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an /// operation cannot be run with input shapes, then it should throw an
/// exception. /// exception.
...@@ -55,6 +57,8 @@ struct operation ...@@ -55,6 +57,8 @@ struct operation
/// Returns true if operation does not require a context to run compute /// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x); bool is_context_free(const operation& x);
/// Returns true if the operation has a finalize method
bool has_finalize(const operation& x);
#else #else
...@@ -189,16 +193,54 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -189,16 +193,54 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return output_alias_op(rank<1>{}, x, shapes); return output_alias_op(rank<1>{}, x, shapes);
} }
template <class T>
auto finalize_op(rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
{
x.finalize(auto_any_cast(ctx), output_shape, input);
}
template <class T>
void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{}
template <class T>
void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
finalize_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
auto has_finalize_op(rank<1>,
T& x,
context& ctx,
const shape& output_shape,
const std::vector<shape>& input)
-> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});
template <class T>
auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
-> std::false_type;
template <class T>
auto has_finalize_op(const T&) -> decltype(has_finalize_op(
rank<1>{}, std::declval<T&>(), std::declval<context&>(), std::declval<const shape&>(), std::declval<std::vector<shape>>()))
{
return {};
}
<% <%
interface( interface(
'operation', 'operation',
virtual('name', returns = 'std::string', const = True), virtual('name', returns = 'std::string', const = True),
virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'), virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'),
virtual('has_finalize', returns = 'bool', const = True, default = 'has_finalize_op'),
virtual('output_alias', virtual('output_alias',
returns = 'int', returns = 'int',
input = 'const std::vector<shape>&', input = 'const std::vector<shape>&',
const = True, const = True,
default = 'output_alias_op'), default = 'output_alias_op'),
virtual('finalize', ctx = 'context&', output = 'const shape&', input = 'const std::vector<shape>&', default = 'finalize_op'),
virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True), virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
virtual('compute', virtual('compute',
returns = 'argument', returns = 'argument',
...@@ -237,6 +279,14 @@ bool is_context_free(const T& x) ...@@ -237,6 +279,14 @@ bool is_context_free(const T& x)
return is_context_free_op(x); return is_context_free_op(x);
} }
inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template <class T>
bool has_finalize(const T& x)
{
return has_finalize_op(x);
}
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -179,7 +179,7 @@ nonvirtual_member = string.Template(''' ...@@ -179,7 +179,7 @@ nonvirtual_member = string.Template('''
${friend} ${return_type} ${name}(${params}) ${const} ${friend} ${return_type} ${name}(${params}) ${const}
{ {
assert(${this}.private_detail_te_handle_mem_var); assert(${this}.private_detail_te_handle_mem_var);
return ${this}.private_detail_te_get_handle().${internal_name}(${member_args}); ${return_} ${this}.private_detail_te_get_handle().${internal_name}(${member_args});
} }
''') ''')
...@@ -189,7 +189,7 @@ virtual_member = string.Template(''' ...@@ -189,7 +189,7 @@ virtual_member = string.Template('''
${return_type} ${internal_name}(${member_params}) ${member_const} override ${return_type} ${internal_name}(${member_params}) ${member_const} override
{ {
${using} ${using}
return ${call}; ${return_} ${call};
} }
''') ''')
...@@ -240,7 +240,8 @@ def convert_member(d, struct_name): ...@@ -240,7 +240,8 @@ def convert_member(d, struct_name):
'friend': '', 'friend': '',
'this': '(*this)', 'this': '(*this)',
'using': '', 'using': '',
'brief': '' 'brief': '',
'return_': ''
} }
args = [] args = []
params = [] params = []
...@@ -257,7 +258,8 @@ def convert_member(d, struct_name): ...@@ -257,7 +258,8 @@ def convert_member(d, struct_name):
for x in d[name]: for x in d[name]:
t = d[name][x] t = d[name][x]
if x == 'return': if x == 'return':
member['return_type'] = t member['return_type'] = t if t else 'void'
if member['return_type'] != 'void': member['return_'] = 'return'
elif x == 'const': elif x == 'const':
member['const'] = 'const' member['const'] = 'const'
member['member_const'] = 'const' member['member_const'] = 'const'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment