Commit 9c5f6324 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into keep_std_shape

parents 90f10299 332cb710
...@@ -35,7 +35,7 @@ struct argmax ...@@ -35,7 +35,7 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
lens[axis] = 1; lens[axis] = 1;
......
...@@ -35,7 +35,7 @@ struct argmin ...@@ -35,7 +35,7 @@ struct argmin
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
lens[axis] = 1; lens[axis] = 1;
......
...@@ -10,20 +10,27 @@ namespace onnx { ...@@ -10,20 +10,27 @@ namespace onnx {
struct parse_hardsigmoid : op_parser<parse_hardsigmoid> struct parse_hardsigmoid : op_parser<parse_hardsigmoid>
{ {
std::vector<op_desc> operators() const { return {{"HardSigmoid"}}; } std::vector<op_desc> operators() const { return {{"HardSigmoid"}, {"HardSwish"}}; }
instruction_ref parse(const op_desc& /*opd*/, instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/, const onnx_parser& /*parser*/,
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
float alpha = 0.2; float alpha = 0.2;
float beta = 0.5; float beta = 0.5;
if(contains(info.attributes, "alpha")) if(opd.onnx_name == "HardSwish")
alpha = info.attributes.at("alpha").f(); {
alpha = 1.0 / 6.0;
}
else
{
if(contains(info.attributes, "alpha"))
alpha = info.attributes.at("alpha").f();
if(contains(info.attributes, "beta")) if(contains(info.attributes, "beta"))
beta = info.attributes.at("beta").f(); beta = info.attributes.at("beta").f();
}
auto input_lens = args[0]->get_shape().lens(); auto input_lens = args[0]->get_shape().lens();
auto input_type = args[0]->get_shape().type(); auto input_type = args[0]->get_shape().type();
...@@ -40,9 +47,13 @@ struct parse_hardsigmoid : op_parser<parse_hardsigmoid> ...@@ -40,9 +47,13 @@ struct parse_hardsigmoid : op_parser<parse_hardsigmoid>
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1}})); info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1}}));
auto mul = info.add_instruction(migraphx::make_op("mul"), mb_alpha, args[0]); auto mul = info.add_instruction(migraphx::make_op("mul"), mb_alpha, args[0]);
auto add = info.add_instruction(migraphx::make_op("add"), mb_beta, mul); auto add = info.add_instruction(migraphx::make_op("add"), mb_beta, mul);
return info.add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one); auto hardsigmoid = info.add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one);
if(opd.onnx_name == "HardSwish")
return info.add_instruction(migraphx::make_op("mul"), args[0], hardsigmoid);
return hardsigmoid;
} }
}; };
......
...@@ -9,7 +9,7 @@ namespace gpu { ...@@ -9,7 +9,7 @@ namespace gpu {
shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2).standard(); check_shapes{inputs, *this}.has(2);
return op.normalize_compute_shape({inputs.at(0)}); return op.normalize_compute_shape({inputs.at(0)});
} }
......
...@@ -9,7 +9,7 @@ namespace gpu { ...@@ -9,7 +9,7 @@ namespace gpu {
shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2).standard(); check_shapes{inputs, *this}.has(2);
return op.normalize_compute_shape({inputs.at(0)}); return op.normalize_compute_shape({inputs.at(0)});
} }
......
...@@ -76,8 +76,9 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a ...@@ -76,8 +76,9 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
size_t batch_item_num = batch_lens[axis]; size_t batch_item_num = batch_lens[axis];
batch_lens[axis] = 1; batch_lens[axis] = 1;
migraphx::shape batch_shape{arg_shape.type(), batch_lens}; migraphx::shape batch_shape{arg_shape.type(), batch_lens};
migraphx::shape std_arg_shape{arg_shape.type(), arg_shape.lens()};
hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) { hip_visit_all(arg, std_arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
auto* output = device_cast(result.get<int64_t>().data()); auto* output = device_cast(result.get<int64_t>().data());
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
// use one block for items in one batch. // use one block for items in one batch.
......
...@@ -1694,6 +1694,16 @@ def hardsigmoid_verify_test(): ...@@ -1694,6 +1694,16 @@ def hardsigmoid_verify_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def hardswish_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 5])
node = onnx.helper.make_node('HardSwish', inputs=['x'], outputs=['y'])
return ([node], [x], [y])
@onnx_test @onnx_test
def if_else_test(): def if_else_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3]) x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
......
hardswish_test:M

xy" HardSwishhardswish_testZ
x


b
y


B
\ No newline at end of file
...@@ -1687,6 +1687,41 @@ TEST_CASE(hardsigmoid_half_test) ...@@ -1687,6 +1687,41 @@ TEST_CASE(hardsigmoid_half_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(hardswish_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<std::size_t> input_lens{2, 5};
auto input_type = migraphx::shape::float_type;
migraphx::shape s{input_type, input_lens};
auto x = mm->add_parameter("x", s);
float alpha = 1.0 / 6.0;
float beta = 0.5;
auto mb_alpha = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}}));
auto mb_beta = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {beta}}));
auto mb_zero =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {0}}));
auto mb_one =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1}}));
auto mul = mm->add_instruction(migraphx::make_op("mul"), mb_alpha, x);
auto add = mm->add_instruction(migraphx::make_op("add"), mb_beta, mul);
auto hardsigmoid = mm->add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one);
mm->add_instruction(migraphx::make_op("mul"), x, hardsigmoid);
auto prog = optimize_onnx("hardswish_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_else_test) TEST_CASE(if_else_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -119,6 +119,7 @@ def create_backend_test(testname=None, target_device=None): ...@@ -119,6 +119,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_globalmaxpool.*') backend_test.include(r'.*test_globalmaxpool.*')
backend_test.include(r'.*test_greater.*') backend_test.include(r'.*test_greater.*')
backend_test.include(r'.*test_hardsigmoid.*') backend_test.include(r'.*test_hardsigmoid.*')
backend_test.include(r'.*test_hardswish.*')
backend_test.include(r'.*test_identity.*') backend_test.include(r'.*test_identity.*')
backend_test.include(r'.*test_if.*') backend_test.include(r'.*test_if.*')
backend_test.include(r'.*test_LeakyReLU*') backend_test.include(r'.*test_LeakyReLU*')
......
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/pass_manager.hpp>
#include "test.hpp"
TEST_CASE(argmax_test_nonstd_shape)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
auto dl_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmax", {{"axis", -3}}), dl_trans);
auto p_uncompiled = p;
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto res_gold = p_uncompiled.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec));
}
TEST_CASE(argmin_test_nonstd_shape)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
auto dl_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), dl_trans);
auto p_uncompiled = p;
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
auto res_gold = p_uncompiled.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -2,34 +2,92 @@ ...@@ -2,34 +2,92 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/argmax.hpp> #include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp> #include <migraphx/op/argmin.hpp>
template <class T, int Axis> template <class T, int Axis, int NonStdShape>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis>> struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 1025}}; migraphx::shape s{migraphx::shape::float_type, {2, 1, 4, 1025}};
auto param = mm->add_parameter("data", s); auto param = mm->add_parameter("data", s);
switch(NonStdShape)
{
case 0:
param = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), param);
break;
case 1:
param = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 1025}}}), param);
break;
case 2:
param = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}), param);
break;
default: break;
}
mm->add_instruction(T{Axis}, param); mm->add_instruction(T{Axis}, param);
return p; return p;
} }
}; };
// transpose argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0>; template struct test_arg_ops<migraphx::op::argmax, 0, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1>; template struct test_arg_ops<migraphx::op::argmax, 1, 0>;
template struct test_arg_ops<migraphx::op::argmax, 2>; template struct test_arg_ops<migraphx::op::argmax, 2, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3>; template struct test_arg_ops<migraphx::op::argmax, 3, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1>; template struct test_arg_ops<migraphx::op::argmax, -1, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2>; template struct test_arg_ops<migraphx::op::argmax, -2, 0>;
// transpose argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0>; template struct test_arg_ops<migraphx::op::argmin, 0, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1>; template struct test_arg_ops<migraphx::op::argmin, 1, 0>;
template struct test_arg_ops<migraphx::op::argmin, 2>; template struct test_arg_ops<migraphx::op::argmin, 2, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3>; template struct test_arg_ops<migraphx::op::argmin, 3, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3>; template struct test_arg_ops<migraphx::op::argmin, -3, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4>; template struct test_arg_ops<migraphx::op::argmin, -4, 0>;
// broadcast argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 1>;
template struct test_arg_ops<migraphx::op::argmax, 1, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, 1>;
// broadcast argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 1>;
template struct test_arg_ops<migraphx::op::argmin, 1, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, 1>;
// slice argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 2>;
template struct test_arg_ops<migraphx::op::argmax, 1, 2>;
template struct test_arg_ops<migraphx::op::argmax, 2, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, 2>;
// slice argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 2>;
template struct test_arg_ops<migraphx::op::argmin, 1, 2>;
template struct test_arg_ops<migraphx::op::argmin, 2, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, 2>;
// default case, standard shape argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 3>;
template struct test_arg_ops<migraphx::op::argmax, 1, 3>;
template struct test_arg_ops<migraphx::op::argmax, 2, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, 3>;
// default case, standard shape argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 3>;
template struct test_arg_ops<migraphx::op::argmin, 1, 3>;
template struct test_arg_ops<migraphx::op::argmin, 2, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, 3>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment