Commit 4e3ca586 authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into broadcast_attr

parents 1775c5ad 31b2c735
gather-example:Ž
'
data
indicesy"Gather*
axis  test_gatherZ
data




Z
indices


b
y




B
\ No newline at end of file
group_conv-example:

0
12"Conv*
grouptest-group_convZ
0




Z
1




b
2




B
\ No newline at end of file
...@@ -351,8 +351,8 @@ TEST_CASE(implicit_bcast_test) ...@@ -351,8 +351,8 @@ TEST_CASE(implicit_bcast_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 4, 5}}, l0); auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3); p.add_instruction(migraphx::op::add{}, l2, l3);
auto prog = migraphx::parse_onnx("implicit_bcast_test.onnx"); auto prog = migraphx::parse_onnx("implicit_bcast_test.onnx");
...@@ -400,6 +400,45 @@ TEST_CASE(reshape_test) ...@@ -400,6 +400,45 @@ TEST_CASE(reshape_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(shape_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto l0 = p.add_parameter("x", s);
migraphx::shape s_shape{migraphx::shape::int64_type, {4}};
p.add_literal(s_shape, l0->get_shape().lens());
auto prog = migraphx::parse_onnx("shape_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
std::size_t axis = 1;
p.add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = migraphx::parse_onnx("gather_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(shape_gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {7, 3, 10}});
auto l1 =
p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
migraphx::shape const_shape{migraphx::shape::int32_type, {1}};
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
std::size_t axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l2);
auto prog = migraphx::parse_onnx("shape_gather.onnx");
EXPECT(p == prog);
}
TEST_CASE(flatten_test) TEST_CASE(flatten_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -455,17 +494,43 @@ TEST_CASE(constant_test) ...@@ -455,17 +494,43 @@ TEST_CASE(constant_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(constant_fill_test)
{
{
migraphx::program p;
auto l0 = p.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2}}, {2, 3}});
std::vector<std::size_t> dims(l0->get_shape().elements());
migraphx::literal ls = l0->get_literal();
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{migraphx::shape::float_type, dims};
std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("const_fill1.onnx");
EXPECT(p == prog);
}
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("const_fill2.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gemm_test) TEST_CASE(gemm_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {}}); p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {}});
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0); auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto d0 = p.add_instruction(migraphx::op::dot{2, 2}, t0, t1); auto alpha = 2.f;
auto b0 = p.add_instruction(migraphx::op::broadcast{1, d0->get_shape()}, l2); p.add_instruction(migraphx::op::dot{alpha}, t0, t1);
p.add_instruction(migraphx::op::add{}, d0, b0);
auto prog = migraphx::parse_onnx("gemm_test.onnx"); auto prog = migraphx::parse_onnx("gemm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -477,12 +542,23 @@ TEST_CASE(add_scalar_test) ...@@ -477,12 +542,23 @@ TEST_CASE(add_scalar_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 0, 5}}, l0); auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 0, 5}}, l1); auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1); p.add_instruction(migraphx::op::add{}, m0, m1);
auto prog = migraphx::parse_onnx("add_scalar_test.onnx"); auto prog = migraphx::parse_onnx("add_scalar_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(group_conv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}});
migraphx::op::convolution op;
op.group = 4;
p.add_instruction(op, l0, l1);
auto prog = migraphx::parse_onnx("group_conv_test.onnx");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
 shape-example:I
xy"Shape
test_shapeZ
x




b
y

B
\ No newline at end of file
...@@ -212,4 +212,24 @@ TEST_CASE(multibroadcast) ...@@ -212,4 +212,24 @@ TEST_CASE(multibroadcast)
} }
} }
TEST_CASE(gather)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 4;
throws_shape(migraphx::op::gather{axis}, input, indices);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -103,4 +103,64 @@ TEST_CASE(operation_default_print) ...@@ -103,4 +103,64 @@ TEST_CASE(operation_default_print)
EXPECT(s == "simple"); EXPECT(s == "simple");
} }
struct final_operation
{
std::string name() const { return "final"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{
MIGRAPHX_THROW("not computable");
}
void
finalize(migraphx::context&, const migraphx::shape&, const std::vector<migraphx::shape>&) const
{
}
};
struct final_operation_throw
{
std::string name() const { return "final"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{
MIGRAPHX_THROW("not computable");
}
[[gnu::noreturn]] void
finalize(migraphx::context&, const migraphx::shape&, const std::vector<migraphx::shape>&) const
{
MIGRAPHX_THROW("finalize");
}
};
TEST_CASE(check_has_finalize_simple)
{
migraphx::operation op = simple_operation{};
EXPECT(not migraphx::has_finalize(op));
}
TEST_CASE(check_has_finalize)
{
migraphx::operation op = final_operation{};
EXPECT(migraphx::has_finalize(op));
}
TEST_CASE(check_run_finalize)
{
migraphx::operation op = final_operation{};
migraphx::context ctx{};
op.finalize(ctx, {}, {});
}
TEST_CASE(check_run_finalize_simple)
{
migraphx::operation op = simple_operation{};
migraphx::context ctx{};
op.finalize(ctx, {}, {});
}
TEST_CASE(check_run_finalize_throw)
{
migraphx::operation op = final_operation_throw{};
migraphx::context ctx{};
EXPECT(test::throws([&] { op.finalize(ctx, {}, {}); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -117,4 +117,21 @@ TEST_CASE(single_transpose_sin_pass) ...@@ -117,4 +117,21 @@ TEST_CASE(single_transpose_sin_pass)
EXPECT(result != get_2x2()); EXPECT(result != get_2x2());
} }
TEST_CASE(reshape_transpose)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
auto x = p.add_parameter("x", s);
auto r1 = p.add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, x);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1);
auto ct = p.add_instruction(migraphx::op::contiguous{}, t);
auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
p.add_instruction(pass_op{}, r2);
EXPECT(p.get_shape() == s);
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == s);
EXPECT(std::distance(p.begin(), p.end()) == n);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -26,6 +26,8 @@ struct operation ...@@ -26,6 +26,8 @@ struct operation
{ {
/// A unique name identifying the operation /// A unique name identifying the operation
std::string name() const; std::string name() const;
/// An optional method that can be used to finalize the operator before running
void finalize(context& ctx);
/// This is used to compute the resulting shape from an operation. If an /// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an /// operation cannot be run with input shapes, then it should throw an
/// exception. /// exception.
...@@ -53,6 +55,11 @@ struct operation ...@@ -53,6 +55,11 @@ struct operation
friend std::ostream& operator<<(std::ostream& os, const operation& op); friend std::ostream& operator<<(std::ostream& os, const operation& op);
}; };
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
/// Returns true if the operation has a finalize method
bool has_finalize(const operation& x);
#else #else
namespace operation_stream { namespace operation_stream {
...@@ -89,7 +96,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -89,7 +96,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_equal } // namespace operation_equal
template <class T> template <class T>
auto compute_op(rank<1>, auto compute_op(rank<2>,
const T& x, const T& x,
context& ctx, context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -99,6 +106,14 @@ auto compute_op(rank<1>, ...@@ -99,6 +106,14 @@ auto compute_op(rank<1>,
return x.compute(auto_any_cast(ctx), output_shape, input); return x.compute(auto_any_cast(ctx), output_shape, input);
} }
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T> template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&) argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{ {
...@@ -110,7 +125,53 @@ template <class T> ...@@ -110,7 +125,53 @@ template <class T>
argument argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input) compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return compute_op(rank<1>{}, x, ctx, output_shape, input); return compute_op(rank<2>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, output_shape, input);
}
template <class T>
auto is_context_free_op(rank<1>,
const T& x,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input), std::true_type{});
template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
-> std::false_type;
template <class T>
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
rank<1>{}, x, std::declval<const shape&>(), std::declval<std::vector<argument>>()))
{
return {};
} }
template <class T> template <class T>
...@@ -132,15 +193,60 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -132,15 +193,60 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return output_alias_op(rank<1>{}, x, shapes); return output_alias_op(rank<1>{}, x, shapes);
} }
template <class T>
auto finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
{
x.finalize(auto_any_cast(ctx), output_shape, input);
}
template <class T>
void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
}
template <class T>
void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
finalize_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
auto has_finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});
template <class T>
auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
-> std::false_type;
template <class T>
auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
std::declval<T&>(),
std::declval<context&>(),
std::declval<const shape&>(),
std::declval<std::vector<shape>>()))
{
return {};
}
<% <%
interface( interface(
'operation', 'operation',
virtual('name', returns = 'std::string', const = True), virtual('name', returns = 'std::string', const = True),
virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'),
virtual('has_finalize', returns = 'bool', const = True, default = 'has_finalize_op'),
virtual('output_alias', virtual('output_alias',
returns = 'int', returns = 'int',
input = 'const std::vector<shape>&', input = 'const std::vector<shape>&',
const = True, const = True,
default = 'output_alias_op'), default = 'output_alias_op'),
virtual('finalize',
ctx = 'context&',
output = 'const shape&',
input = 'const std::vector<shape>&',
default = 'finalize_op'),
virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True), virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
virtual('compute', virtual('compute',
returns = 'argument', returns = 'argument',
...@@ -149,6 +255,12 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -149,6 +255,12 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
input = 'const std::vector<argument>&', input = 'const std::vector<argument>&',
const = True, const = True,
default = 'compute_op'), default = 'compute_op'),
virtual('compute',
returns = 'argument',
output = 'const shape&',
input = 'const std::vector<argument>&',
const = True,
default = 'compute_op'),
friend('operator<<', friend('operator<<',
returns = 'std::ostream &', returns = 'std::ostream &',
os = 'std::ostream &', os = 'std::ostream &',
...@@ -165,6 +277,22 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -165,6 +277,22 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return !(x == y); return !(x == y);
} }
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
template <class T>
bool is_context_free(const T& x)
{
return is_context_free_op(x);
}
inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template <class T>
bool has_finalize(const T& x)
{
return has_finalize_op(x);
}
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -71,6 +71,11 @@ struct ${struct_name} ...@@ -71,6 +71,11 @@ struct ${struct_name}
${nonvirtual_members} ${nonvirtual_members}
friend bool is_shared(const ${struct_name} & private_detail_x, const ${struct_name} & private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var == private_detail_y.private_detail_te_handle_mem_var;
}
private: private:
struct private_detail_te_handle_base_type struct private_detail_te_handle_base_type
{ {
...@@ -179,7 +184,7 @@ nonvirtual_member = string.Template(''' ...@@ -179,7 +184,7 @@ nonvirtual_member = string.Template('''
${friend} ${return_type} ${name}(${params}) ${const} ${friend} ${return_type} ${name}(${params}) ${const}
{ {
assert(${this}.private_detail_te_handle_mem_var); assert(${this}.private_detail_te_handle_mem_var);
return ${this}.private_detail_te_get_handle().${internal_name}(${member_args}); ${return_} ${this}.private_detail_te_get_handle().${internal_name}(${member_args});
} }
''') ''')
...@@ -189,7 +194,7 @@ virtual_member = string.Template(''' ...@@ -189,7 +194,7 @@ virtual_member = string.Template('''
${return_type} ${internal_name}(${member_params}) ${member_const} override ${return_type} ${internal_name}(${member_params}) ${member_const} override
{ {
${using} ${using}
return ${call}; ${return_} ${call};
} }
''') ''')
...@@ -240,7 +245,8 @@ def convert_member(d, struct_name): ...@@ -240,7 +245,8 @@ def convert_member(d, struct_name):
'friend': '', 'friend': '',
'this': '(*this)', 'this': '(*this)',
'using': '', 'using': '',
'brief': '' 'brief': '',
'return_': ''
} }
args = [] args = []
params = [] params = []
...@@ -257,7 +263,8 @@ def convert_member(d, struct_name): ...@@ -257,7 +263,8 @@ def convert_member(d, struct_name):
for x in d[name]: for x in d[name]:
t = d[name][x] t = d[name][x]
if x == 'return': if x == 'return':
member['return_type'] = t member['return_type'] = t if t else 'void'
if member['return_type'] != 'void': member['return_'] = 'return'
elif x == 'const': elif x == 'const':
member['const'] = 'const' member['const'] = 'const'
member['member_const'] = 'const' member['member_const'] = 'const'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment