Unverified commit 85ed5718 authored by Shucai Xiao, committed by GitHub

Pow different data type (#707)



* add support for different input data types for the pow operator

* clang format

* fix cppcheck error

* clang format

* add unit test for the pow operator with different input data types

* clang format

* remove unnecessary comments

* fix review comments

* clang format

* fix an issue related to the hash table key type

* clang format
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent ceb4ca09
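
In short: ONNX allows Pow to take a base and an exponent with different data types, but the generic binary-op parser assumed matching input types. This change removes Pow from that parser and gives it a dedicated one that promotes both inputs to a common compute type, applies pow, and, when the base itself was promoted, converts the result back to the base's original type.
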
@@ -11,12 +11,7 @@ struct parse_binary_op : op_parser<parse_binary_op>
{
std::vector<op_desc> operators() const
{
return {{"Add", "add"},
{"Div", "div"},
{"Mul", "mul"},
{"Pow", "pow"},
{"PRelu", "prelu"},
{"Sub", "sub"}};
return {{"Add", "add"}, {"Div", "div"}, {"Mul", "mul"}, {"PRelu", "prelu"}, {"Sub", "sub"}};
}
instruction_ref parse(const op_desc& opd,
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
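// Pick the common compute type for the Pow inputs: whichever type has the
// higher rank in op_order below wins (floating point outranks all integers,
// wider types outrank narrower ones, unsigned outranks signed at equal width).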
auto compute_type(shape::type_t t1, shape::type_t t2)
{
const static std::unordered_map<int, int> op_order = {
{static_cast<int>(shape::int8_type), 1},
{static_cast<int>(shape::uint8_type), 2},
{static_cast<int>(shape::int16_type), 3},
{static_cast<int>(shape::uint16_type), 4},
{static_cast<int>(shape::int32_type), 5},
{static_cast<int>(shape::uint32_type), 6},
{static_cast<int>(shape::int64_type), 7},
{static_cast<int>(shape::uint64_type), 8},
{static_cast<int>(shape::half_type), 9},
{static_cast<int>(shape::float_type), 10},
{static_cast<int>(shape::double_type), 11}};
int it1 = static_cast<int>(t1);
int it2 = static_cast<int>(t2);
if(!contains(op_order, it1) or !contains(op_order, it2))
{
MIGRAPHX_THROW("PARSE_POW: Input data type not supported!");
}
return ((op_order.at(it1) >= op_order.at(it2)) ? t1 : t2);
}
struct parse_pow : op_parser<parse_pow>
{
std::vector<op_desc> operators() const { return {{"Pow"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
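        // Promote both inputs to the common compute type, apply pow, and
        // convert the result back if the base type itself was promoted.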
auto type_base = args[0]->get_shape().type();
auto type_exponent = args[1]->get_shape().type();
auto type_compute = compute_type(type_base, type_exponent);
if(type_compute != type_base)
{
args[0] =
info.add_instruction(make_op("convert", {{"target_type", type_compute}}), args[0]);
}
if(type_compute != type_exponent)
{
args[1] =
info.add_instruction(make_op("convert", {{"target_type", type_compute}}), args[1]);
}
auto ret = info.add_broadcastable_binary_op("pow", args[0], args[1]);
if(type_compute != type_base)
{
ret = info.add_instruction(make_op("convert", {{"target_type", type_base}}), ret);
}
return ret;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
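
To make the promotion rule concrete, here is a minimal, self-contained sketch of the same ordering using a stand-in enum (the names dtype and promote are hypothetical, not part of MIGraphX): whichever input has the higher rank determines the compute type, so any integer mixed with a floating-point type computes in the floating-point type.

#include <cassert>

// Stand-in for shape::type_t; enumerator order mirrors the op_order ranks
// above, so a plain integer comparison reproduces the promotion rule.
enum class dtype { i8 = 1, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64 };

dtype promote(dtype a, dtype b)
{
    return static_cast<int>(a) >= static_cast<int>(b) ? a : b;
}

int main()
{
    // int64 base with a float exponent: pow computes in float, and the
    // parser converts the result back to int64 afterwards.
    assert(promote(dtype::i64, dtype::f32) == dtype::f32);
    // float base with an int64 exponent: computes in float, and no
    // conversion back is needed since the base already has the compute type.
    assert(promote(dtype::f32, dtype::i64) == dtype::f32);
    return 0;
}
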
@@ -2114,6 +2114,38 @@ def pow_test():
return ([node], [arg0, arg1], [arg_out])
@onnx_test
def pow_fp32_i64_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
arg1 = helper.make_tensor_value_info('1', TensorProto.INT64, [2, 3, 4, 5])
arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT,
[2, 3, 4, 5])
node = onnx.helper.make_node(
'Pow',
inputs=['0', '1'],
outputs=['out'],
)
return ([node], [arg0, arg1], [arg_out])
@onnx_test
def pow_i64_fp32_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.INT64, [2, 3, 4, 5])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 3, 4, 5])
arg_out = helper.make_tensor_value_info('out', TensorProto.INT64,
[2, 3, 4, 5])
node = onnx.helper.make_node(
'Pow',
inputs=['0', '1'],
outputs=['out'],
)
return ([node], [arg0, arg1], [arg_out])
@onnx_test
def prelu_brcst_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
......
@@ -1905,6 +1905,40 @@ TEST_CASE(pow_test)
EXPECT(p == prog);
}
TEST_CASE(pow_fp32_i64_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 5}});
auto l1f = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("pow"), l0, l1f);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("pow_fp32_i64_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(pow_i64_fp32_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l0f = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l0);
auto fr = mm->add_instruction(migraphx::make_op("pow"), l0f, l1);
auto ir = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), fr);
mm->add_return({ir});
auto prog = migraphx::parse_onnx("pow_i64_fp32_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(prelu_brcst_test)
{
migraphx::program p;
......
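
For orientation, a rough sketch of feeding data through the mixed-type graph parsed from the test file above. This is a sketch under assumptions, not the test suite's method: it assumes the reference target is spelled migraphx::ref::target with the header path shown (older MIGraphX releases used migraphx::cpu::target), and that your build ships the onnx.hpp convenience header.

#include <migraphx/onnx.hpp>        // migraphx::parse_onnx
#include <migraphx/ref/target.hpp>  // assumption: reference target location
#include <cstdint>
#include <vector>

int main()
{
    auto p = migraphx::parse_onnx("pow_i64_fp32_test.onnx");
    p.compile(migraphx::ref::target{});

    std::vector<int64_t> base(2 * 3 * 4 * 5, 2);  // every base element is 2
    std::vector<float> expo(2 * 3 * 4 * 5, 3.0f); // every exponent is 3

    migraphx::parameter_map pm;
    pm["0"] = migraphx::argument{
        migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 5}}, base.data()};
    pm["1"] = migraphx::argument{
        migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, expo.data()};

    auto result = p.eval(pm).back();  // int64 output, every element 2^3 = 8
    return 0;
}
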
[Binary files added: pow_fp32_i64_test.onnx and pow_i64_fp32_test.onnx. Each is a serialized ONNX protobuf containing a single Pow node over {2, 3, 4, 5} inputs (float base with int64 exponent, and int64 base with float exponent, respectively); binary content not rendered here.]
@@ -244,7 +244,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_not_2d_cpu')
backend_test.exclude(r'test_not_3d_cpu')
backend_test.exclude(r'test_not_4d_cpu')
backend_test.exclude(r'test_pow_types_*')
backend_test.exclude(r'test_size_cpu')
backend_test.exclude(r'test_size_example_cpu')
backend_test.exclude(r'test_softmax_cross_entropy_*')
......
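
Dropping this exclusion re-enables the upstream test_pow_types_* ONNX backend tests, which exercise Pow with mixed input data types and should now pass with the dedicated parser above.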