Commit fa4cf4c4 authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into pad_op

parents e067dce0 31b2c735
...@@ -103,4 +103,64 @@ TEST_CASE(operation_default_print) ...@@ -103,4 +103,64 @@ TEST_CASE(operation_default_print)
EXPECT(s == "simple"); EXPECT(s == "simple");
} }
struct final_operation
{
std::string name() const { return "final"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{
MIGRAPHX_THROW("not computable");
}
void
finalize(migraphx::context&, const migraphx::shape&, const std::vector<migraphx::shape>&) const
{
}
};
struct final_operation_throw
{
std::string name() const { return "final"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{
MIGRAPHX_THROW("not computable");
}
[[gnu::noreturn]] void
finalize(migraphx::context&, const migraphx::shape&, const std::vector<migraphx::shape>&) const
{
MIGRAPHX_THROW("finalize");
}
};
TEST_CASE(check_has_finalize_simple)
{
migraphx::operation op = simple_operation{};
EXPECT(not migraphx::has_finalize(op));
}
TEST_CASE(check_has_finalize)
{
migraphx::operation op = final_operation{};
EXPECT(migraphx::has_finalize(op));
}
TEST_CASE(check_run_finalize)
{
migraphx::operation op = final_operation{};
migraphx::context ctx{};
op.finalize(ctx, {}, {});
}
TEST_CASE(check_run_finalize_simple)
{
migraphx::operation op = simple_operation{};
migraphx::context ctx{};
op.finalize(ctx, {}, {});
}
TEST_CASE(check_run_finalize_throw)
{
migraphx::operation op = final_operation_throw{};
migraphx::context ctx{};
EXPECT(test::throws([&] { op.finalize(ctx, {}, {}); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -26,6 +26,8 @@ struct operation ...@@ -26,6 +26,8 @@ struct operation
{ {
/// A unique name identifying the operation /// A unique name identifying the operation
std::string name() const; std::string name() const;
/// An optional method that can be used to finalize the operator before running
void finalize(context& ctx);
/// This is used to compute the resulting shape from an operation. If an /// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an /// operation cannot be run with input shapes, then it should throw an
/// exception. /// exception.
...@@ -55,6 +57,8 @@ struct operation ...@@ -55,6 +57,8 @@ struct operation
/// Returns true if operation does not require a context to run compute /// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x); bool is_context_free(const operation& x);
/// Returns true if the operation has a finalize method
bool has_finalize(const operation& x);
#else #else
...@@ -189,16 +193,60 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -189,16 +193,60 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return output_alias_op(rank<1>{}, x, shapes); return output_alias_op(rank<1>{}, x, shapes);
} }
template <class T>
auto finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
{
x.finalize(auto_any_cast(ctx), output_shape, input);
}
template <class T>
void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
}
template <class T>
void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
finalize_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
auto has_finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});
template <class T>
auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
-> std::false_type;
template <class T>
auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
std::declval<T&>(),
std::declval<context&>(),
std::declval<const shape&>(),
std::declval<std::vector<shape>>()))
{
return {};
}
<% <%
interface( interface(
'operation', 'operation',
virtual('name', returns = 'std::string', const = True), virtual('name', returns = 'std::string', const = True),
virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'), virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'),
virtual('has_finalize', returns = 'bool', const = True, default = 'has_finalize_op'),
virtual('output_alias', virtual('output_alias',
returns = 'int', returns = 'int',
input = 'const std::vector<shape>&', input = 'const std::vector<shape>&',
const = True, const = True,
default = 'output_alias_op'), default = 'output_alias_op'),
virtual('finalize',
ctx = 'context&',
output = 'const shape&',
input = 'const std::vector<shape>&',
default = 'finalize_op'),
virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True), virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
virtual('compute', virtual('compute',
returns = 'argument', returns = 'argument',
...@@ -237,6 +285,14 @@ bool is_context_free(const T& x) ...@@ -237,6 +285,14 @@ bool is_context_free(const T& x)
return is_context_free_op(x); return is_context_free_op(x);
} }
inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template <class T>
bool has_finalize(const T& x)
{
return has_finalize_op(x);
}
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -71,6 +71,11 @@ struct ${struct_name} ...@@ -71,6 +71,11 @@ struct ${struct_name}
${nonvirtual_members} ${nonvirtual_members}
friend bool is_shared(const ${struct_name} & private_detail_x, const ${struct_name} & private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var == private_detail_y.private_detail_te_handle_mem_var;
}
private: private:
struct private_detail_te_handle_base_type struct private_detail_te_handle_base_type
{ {
...@@ -179,7 +184,7 @@ nonvirtual_member = string.Template(''' ...@@ -179,7 +184,7 @@ nonvirtual_member = string.Template('''
${friend} ${return_type} ${name}(${params}) ${const} ${friend} ${return_type} ${name}(${params}) ${const}
{ {
assert(${this}.private_detail_te_handle_mem_var); assert(${this}.private_detail_te_handle_mem_var);
return ${this}.private_detail_te_get_handle().${internal_name}(${member_args}); ${return_} ${this}.private_detail_te_get_handle().${internal_name}(${member_args});
} }
''') ''')
...@@ -189,7 +194,7 @@ virtual_member = string.Template(''' ...@@ -189,7 +194,7 @@ virtual_member = string.Template('''
${return_type} ${internal_name}(${member_params}) ${member_const} override ${return_type} ${internal_name}(${member_params}) ${member_const} override
{ {
${using} ${using}
return ${call}; ${return_} ${call};
} }
''') ''')
...@@ -240,7 +245,8 @@ def convert_member(d, struct_name): ...@@ -240,7 +245,8 @@ def convert_member(d, struct_name):
'friend': '', 'friend': '',
'this': '(*this)', 'this': '(*this)',
'using': '', 'using': '',
'brief': '' 'brief': '',
'return_': ''
} }
args = [] args = []
params = [] params = []
...@@ -257,7 +263,8 @@ def convert_member(d, struct_name): ...@@ -257,7 +263,8 @@ def convert_member(d, struct_name):
for x in d[name]: for x in d[name]:
t = d[name][x] t = d[name][x]
if x == 'return': if x == 'return':
member['return_type'] = t member['return_type'] = t if t else 'void'
if member['return_type'] != 'void': member['return_'] = 'return'
elif x == 'const': elif x == 'const':
member['const'] = 'const' member['const'] = 'const'
member['member_const'] = 'const' member['member_const'] = 'const'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment