Unverified Commit 59b80d4e authored by Shucai Xiao, committed by GitHub

Bool type and equal operator (#603)



* add bool type

* code backup

* code backup

* clang format

* fix build warnings

* clang format

* add the equal operator

* add the equal operator

* clang format

* remove unnecessary code

* refine unit tests

* clang format

* fix review comments and a bug

* clang format

* additional changes

* clang format

* fix cppcheck error

* add bool type in c api

* fix cppcheck error

* fix review comments

* fix cppcheck error

* fix a build error related to gcc

* fix cppcheck error

* fix cppcheck error

* added the equal operator to register list

* add parsing boolean type

* clang format

* fix bool type issue for python output

* clang format

* add support for automatic multibroadcast of the equal operator

* additional unit tests for more code coverage

* clang format

* missing an onnx file
Co-authored-by: Paul Fultz II <pfultz2@yahoo.com>
parent 3eb4f775
@@ -72,6 +72,7 @@ register_migraphx_ops(
     div
     dot
     elu
+    equal
     erf
     exp
     flatten
......
@@ -6,6 +6,7 @@
 // Add new types here
 // clang-format off
 #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
+    m(bool_type, bool) \
     m(half_type, half) \
     m(float_type, float) \
     m(double_type, double) \
......
@@ -37,6 +37,12 @@ constexpr T normalize(unsigned long z)
     return z % max;
 }
 
+template <class T, MIGRAPHX_REQUIRES(std::is_same<T, bool>{})>
+constexpr bool normalize(unsigned long z)
+{
+    return static_cast<bool>(z % 2);
+}
+
 template <class T>
 struct xorshf96_generator
 {
......
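Note: the new overload makes the test-data generator produce boolean values from the low bit of the pseudorandom stream. A standalone sketch of the same two-way dispatch, spelling out with std::enable_if_t the SFINAE that the MIGRAPHX_REQUIRES macro wraps (the max computation and the non-bool constraint on the primary overload are assumptions here, filled in so the sketch compiles on its own):

#include <iostream>
#include <limits>
#include <type_traits>

// Non-bool types: reduce modulo the type's maximum (assumed to mirror
// the primary normalize above, whose surrounding lines are elided).
template <class T, std::enable_if_t<!std::is_same<T, bool>{}, int> = 0>
constexpr T normalize(unsigned long z)
{
    return static_cast<T>(z % std::numeric_limits<T>::max());
}

// bool: keep only the low bit, as in the new overload.
template <class T, std::enable_if_t<std::is_same<T, bool>{}, int> = 0>
constexpr bool normalize(unsigned long z)
{
    return static_cast<bool>(z % 2);
}

int main()
{
    std::cout << normalize<int>(300) << " " << normalize<bool>(3) << "\n"; // prints: 300 1
}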
#ifndef MIGRAPHX_GUARD_OPERATORS_EQUAL_HPP
#define MIGRAPHX_GUARD_OPERATORS_EQUAL_HPP

#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct equal : binary<equal>
{
    auto apply() const
    {
        return [](auto x, auto y) { return float_equal(x, y); };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
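Note: op::equal reuses migraphx's float_equal for its element-wise predicate, so exactly-representable values compare equal without rounding noise producing spurious mismatches. As an illustration only (not the actual <migraphx/float_equal.hpp> implementation), a one-ULP-tolerant comparison can be sketched like this:

#include <cassert>
#include <cmath>
#include <limits>

// Sketch of a float_equal-style predicate: two finite values compare
// equal if each lies within one representable step (ULP) of the other.
template <class T>
bool float_equal_sketch(T x, T y)
{
    return std::isfinite(x) and std::isfinite(y) and
           std::nextafter(x, std::numeric_limits<T>::lowest()) <= y and
           std::nextafter(x, std::numeric_limits<T>::max()) >= y;
}

int main()
{
    assert(float_equal_sketch(1.0, 1.0));
    assert(not float_equal_sketch(1.0, 1.0 + 1e-9)); // more than one ULP apart
}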
@@ -29,6 +29,7 @@
 #include <migraphx/op/div.hpp>
 #include <migraphx/op/dot.hpp>
 #include <migraphx/op/elu.hpp>
+#include <migraphx/op/equal.hpp>
 #include <migraphx/op/erf.hpp>
 #include <migraphx/op/exp.hpp>
 #include <migraphx/op/flatten.hpp>
......
@@ -23,6 +23,7 @@ struct shape
 // Add new types here
 // clang-format off
 #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
+    m(bool_type, bool) \
     m(half_type, half) \
     m(float_type, float) \
     m(double_type, double) \
@@ -124,43 +125,43 @@ struct shape
     template <class T>
     struct as
     {
-        using type = T;
+        // bool elements are stored as one-byte int8_t values
+        using type = std::conditional_t<std::is_same<T, bool>{}, int8_t, T>;
 
         template <class U>
-        T operator()(U u) const
+        type operator()(U u) const
         {
-            return T(u);
+            return type(u);
         }
 
         template <class U>
-        T* operator()(U* u) const
+        type* operator()(U* u) const
         {
-            return static_cast<T*>(u);
+            return static_cast<type*>(u);
         }
 
         template <class U>
-        const T* operator()(const U* u) const
+        const type* operator()(const U* u) const
        {
-            return static_cast<const T*>(u);
+            return static_cast<const type*>(u);
        }
 
-        T operator()() const { return {}; }
+        type operator()() const { return {}; }
 
-        std::size_t size(std::size_t n = 1) const { return sizeof(T) * n; }
+        std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; }
 
         template <class U>
-        T* from(U* buffer, std::size_t n = 0) const
+        type* from(U* buffer, std::size_t n = 0) const
        {
-            return reinterpret_cast<T*>(buffer) + n;
+            return reinterpret_cast<type*>(buffer) + n;
        }
 
         template <class U>
-        const T* from(const U* buffer, std::size_t n = 0) const
+        const type* from(const U* buffer, std::size_t n = 0) const
        {
-            return reinterpret_cast<const T*>(buffer) + n;
+            return reinterpret_cast<const type*>(buffer) + n;
        }
 
-        type_t type_enum() const { return get_type<T>{}; }
+        type_t type_enum() const { return get_type<type>{}; }
     };
 
     template <class Visitor>
......
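Note: the key change in shape::as is the nested type alias: bool elements are stored as int8_t, giving every element a well-defined one-byte representation that the raw-buffer pointer arithmetic in the from() helpers can address safely, which the bit-packed std::vector<bool> and implementation-defined bool layout cannot guarantee. A minimal sketch of the mapping:

#include <cstdint>
#include <type_traits>

// Storage-type mapping used by shape::as above: bool is held as a
// one-byte signed integer, all other element types are stored as-is.
template <class T>
using storage_type = std::conditional_t<std::is_same<T, bool>{}, std::int8_t, T>;

static_assert(std::is_same<storage_type<bool>, std::int8_t>{}, "bool stored as int8_t");
static_assert(std::is_same<storage_type<float>, float>{}, "other types unchanged");
static_assert(sizeof(storage_type<bool>) == 1, "one byte per bool element");

int main() {}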
@@ -135,6 +135,7 @@ struct onnx_parser
         add_mem_op("ConvInteger", "quant_convolution", &onnx_parser::parse_conv);
         add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
         add_mem_op("Elu", &onnx_parser::parse_elu);
+        add_mem_op("Equal", &onnx_parser::parse_equal);
         add_mem_op("Expand", &onnx_parser::parse_expand);
         add_mem_op("GatherElements", &onnx_parser::parse_gather_elements);
         add_mem_op("Gemm", &onnx_parser::parse_gemm);
@@ -2398,6 +2399,17 @@ struct onnx_parser
         return prog.add_literal(literal(out_s, out_data));
     }
 
+    instruction_ref
+    parse_equal(const std::string&, const node_info&, std::vector<instruction_ref> args)
+    {
+        auto l = add_broadcastable_binary_op(args[0], args[1], "equal");
+        // the binary op computes in the common input type; ONNX Equal must
+        // return bool, so convert unless the inputs were already bool
+        if(l->get_shape().type() != shape::bool_type)
+        {
+            l = prog.add_instruction(op::convert{shape::bool_type}, l);
+        }
+        return l;
+    }
+
     void parse_from(std::istream& is)
     {
         onnx::ModelProto model;
@@ -2656,12 +2668,13 @@ struct onnx_parser
         case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
         case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
         case onnx::TensorProto::UINT8: shape_type = shape::uint8_type; break;
+        case onnx::TensorProto::BOOL: shape_type = shape::bool_type; break;
         case onnx::TensorProto::STRING:
-        case onnx::TensorProto::BOOL:
         case onnx::TensorProto::UNDEFINED:
         case onnx::TensorProto::COMPLEX64:
         case onnx::TensorProto::COMPLEX128:
-            break; // throw std::runtime_error("Unsupported type");
+            MIGRAPHX_THROW("PARSE_TYPE: unsupported type " +
+                           std::to_string(t.tensor_type().elem_type()));
         }
 
         if(!input_dims.empty())
@@ -2706,6 +2719,7 @@ struct onnx_parser
         case 5: return shape::int16_type;
         case 6: return shape::int32_type;
         case 7: return shape::int64_type;
+        case 9: return shape::bool_type;
         case 10: return shape::half_type;
         case 11: return shape::double_type;
         case 12: return shape::uint32_type;
......
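Note: the bare case 9 above corresponds to BOOL in the ONNX TensorProto::DataType enumeration; the neighboring cases in the switch follow the same numbering (from the ONNX protobuf spec):

// ONNX TensorProto::DataType values covered by the switch above
enum onnx_tensor_data_type
{
    FLOAT   = 1,
    UINT8   = 2,
    INT8    = 3,
    UINT16  = 4,
    INT16   = 5,
    INT32   = 6,
    INT64   = 7,
    STRING  = 8,
    BOOL    = 9, // newly mapped to shape::bool_type
    FLOAT16 = 10,
    DOUBLE  = 11,
    UINT32  = 12,
    UINT64  = 13,
};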
@@ -82,12 +82,26 @@ py::buffer_info to_buffer_info(T& x)
         strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); });
     py::buffer_info b;
     visit_type(s, [&](auto as) {
-        b = py::buffer_info(x.data(),
-                            as.size(),
-                            py::format_descriptor<decltype(as())>::format(),
-                            s.lens().size(),
-                            s.lens(),
-                            strides);
+        // migraphx uses int8_t storage for the bool type, so we need to
+        // explicitly expose the data type as bool to python
+        if(s.type() == migraphx::shape::bool_type)
+        {
+            b = py::buffer_info(x.data(),
+                                as.size(),
+                                py::format_descriptor<bool>::format(),
+                                s.lens().size(),
+                                s.lens(),
+                                strides);
+        }
+        else
+        {
+            b = py::buffer_info(x.data(),
+                                as.size(),
+                                py::format_descriptor<decltype(as())>::format(),
+                                s.lens().size(),
+                                s.lens(),
+                                strides);
+        }
     });
     return b;
 }
......
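Note: the special case exists because a bool_type buffer physically holds int8_t (per the shape::as change above), so the deduced decltype(as()) would advertise the buffer to python as int8. pybind11's format descriptors make the difference visible; a small standalone check, assuming pybind11 is installed:

#include <pybind11/pybind11.h>
#include <cstdint>
#include <iostream>

namespace py = pybind11;

int main()
{
    // numpy struct-format characters: '?' is bool, 'b' is int8
    std::cout << py::format_descriptor<bool>::format() << "\n";        // "?"
    std::cout << py::format_descriptor<std::int8_t>::format() << "\n"; // "b"
}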
@@ -69,6 +69,7 @@ add_library(migraphx_device
     device/tan.cpp
     device/tanh.cpp
     device/rnn_variable_seq_lens.cpp
+    device/equal.cpp
 )
 set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
 rocm_set_soversion(migraphx_device ${MIGRAPHX_SO_VERSION})
@@ -151,6 +152,7 @@ register_migraphx_gpu_ops(hip_
     cosh
     cos
     div
+    equal
     erf
     exp
     floor
......
#include <migraphx/gpu/device/equal.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/type_traits.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Compare within one machine epsilon; for integral T, epsilon() is zero,
// so this reduces to exact equality.
template <class T>
__device__ bool equal(T x, T y)
{
    auto eps  = std::numeric_limits<T>::epsilon();
    auto diff = x - y;
    return (diff <= eps) and (diff >= -eps);
}

void equal(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
    nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return equal(x, y); });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
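Note: the device predicate compares within one machine epsilon rather than with ==; a host-side mirror of the same logic shows both the floating-point band and the integral degenerate case:

#include <cassert>
#include <limits>

// Host-side mirror of the __device__ equal above: values whose
// difference lies within [-epsilon, epsilon] compare equal.
template <class T>
bool equal_approx(T x, T y)
{
    auto eps  = std::numeric_limits<T>::epsilon();
    auto diff = x - y;
    return (diff <= eps) and (diff >= -eps);
}

int main()
{
    assert(equal_approx(1.0f, 1.0f));
    assert(!equal_approx(1.0f, 1.5f));
    assert(equal_approx(3, 3));  // integral case: exact equality
    assert(!equal_approx(3, 4));
}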
@@ -16,6 +16,7 @@ rocblas_datatype get_type(shape::type_t type)
     case shape::uint8_type: return rocblas_datatype_u8_r;
     case shape::int32_type: return rocblas_datatype_i32_r;
     case shape::uint32_type: return rocblas_datatype_u32_r;
+    case shape::bool_type:
     case shape::uint16_type:
     case shape::int16_type:
     case shape::int64_type:
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_EQUAL_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_EQUAL_HPP

#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void equal(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_EQUAL_HPP
#define MIGRAPHX_GUARD_RTGLIB_EQUAL_HPP

#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/equal.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct hip_equal : binary_device<hip_equal, device::equal>
{
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
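Note: hip_equal gets all of its behavior from the binary_device CRTP base; the derived type just names itself and the device entry point. A toy model of the pattern (stand-in types, not the real migraphx templates):

#include <iostream>

// Hypothetical stand-ins for the stream/argument types.
struct stream_t {};
struct argument_t {};

void equal_kernel(stream_t, const argument_t&, const argument_t&, const argument_t&)
{
    std::cout << "launch equal kernel\n";
}

// Toy binary_device: the base captures the device entry point as a
// non-type template parameter, so derived ops add no code of their own.
template <class Derived,
          void (*F)(stream_t, const argument_t&, const argument_t&, const argument_t&)>
struct binary_device
{
    void run(stream_t s, const argument_t& out, const argument_t& a, const argument_t& b) const
    {
        F(s, out, a, b);
    }
};

struct hip_equal : binary_device<hip_equal, &equal_kernel>
{
};

int main() { hip_equal{}.run({}, {}, {}, {}); }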
@@ -22,6 +22,7 @@
 #include <migraphx/gpu/convolution.hpp>
 #include <migraphx/gpu/deconvolution.hpp>
 #include <migraphx/gpu/elu.hpp>
+#include <migraphx/gpu/equal.hpp>
 #include <migraphx/gpu/gemm.hpp>
 #include <migraphx/gpu/hip.hpp>
 #include <migraphx/gpu/int8_conv_pack.hpp>
@@ -107,6 +108,7 @@ struct miopen_apply
         add_generic_op("cos");
         add_generic_op("cosh");
         add_generic_op("div");
+        add_generic_op("equal");
         add_generic_op("erf");
         add_generic_op("exp");
         add_generic_op("floor");
......
@@ -2930,4 +2930,45 @@ TEST_CASE(recip_test)
     EXPECT(migraphx::verify_range(results_vector, gold));
 }
 
+TEST_CASE(equal_test)
+{
+    migraphx::program p;
+    migraphx::shape s{migraphx::shape::float_type, {9}};
+    auto l0 =
+        p.add_literal(migraphx::literal{s, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}});
+    auto l1 =
+        p.add_literal(migraphx::literal{s, {1.1, 1.6, -0.1, -1.2, -1.5, -0.7, 0.0, 2.3, -2.1}});
+    auto eq = p.add_instruction(migraphx::op::equal{}, l0, l1);
+    auto r  = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, eq);
+    p.add_return({r});
+    p.compile(migraphx::cpu::target{});
+    auto result = p.eval({}).back();
+    std::vector<bool> results_vector;
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<bool> gold = {true, false, false, false, true, false, true, false, false};
+    EXPECT(results_vector == gold);
+}
+
+TEST_CASE(equal_brcst_test)
+{
+    migraphx::program p;
+    migraphx::shape s0{migraphx::shape::float_type, {3, 3}};
+    auto l0 =
+        p.add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}});
+    migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
+    auto l1  = p.add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}});
+    auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 3}}, l1);
+    auto eq  = p.add_instruction(migraphx::op::equal{}, l0, bl1);
+    auto r   = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, eq);
+    p.add_return({r});
+    p.compile(migraphx::cpu::target{});
+    auto result = p.eval({}).back();
+    std::vector<bool> results_vector;
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<bool> gold = {true, false, false, false, true, false, true, false, false};
+    EXPECT(results_vector == gold);
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -2857,4 +2857,36 @@ struct test_neg : verify_program<test_neg>
     };
 };
 
+struct test_equal : verify_program<test_equal>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}};
+        auto input1 = p.add_parameter("x", s);
+        auto input2 = p.add_parameter("y", s);
+        auto r      = p.add_instruction(migraphx::op::equal{}, input1, input2);
+        p.add_return({r});
+        return p;
+    };
+};
+
+struct test_equal_brcst : verify_program<test_equal_brcst>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s0{migraphx::shape::float_type, {3, 3}};
+        auto l0 = p.add_parameter("x", s0);
+        migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
+        auto l1  = p.add_parameter("y", s1);
+        auto bl1 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
+        auto r   = p.add_instruction(migraphx::op::equal{}, l0, bl1);
+        p.add_return({r});
+        return p;
+    };
+};
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -85,6 +85,17 @@ inline std::ostream& operator<<(std::ostream& s, const std::vector<T>& v)
     return s;
 }
 
+inline std::ostream& operator<<(std::ostream& s, const std::vector<bool>& v)
+{
+    s << "{ ";
+    for(auto x : v)
+    {
+        s << x << ", ";
+    }
+    s << "}";
+    return s;
+}
+
 template <class T, class U, class Operator>
 struct expression
 {
......
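Note: the dedicated overload covers std::vector<bool>, the bit-packed specialization whose iterators return proxy objects rather than bool&, so failing test assertions can print boolean vectors. Elements print as 1/0 unless std::boolalpha is set. A standalone demo, with the overload body copied from above:

#include <iostream>
#include <vector>

inline std::ostream& operator<<(std::ostream& s, const std::vector<bool>& v)
{
    s << "{ ";
    for(auto x : v)
    {
        s << x << ", ";
    }
    s << "}";
    return s;
}

int main()
{
    std::vector<bool> v = {true, false, true};
    std::cout << v << "\n"; // prints "{ 1, 0, 1, }"
}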
(binary ONNX protobuf: equal_bool_test.onnx — a two-node graph that Casts float input x1 to bool, then applies Equal(bx1, x2) to produce y; generated by equal_bool_test() in the script below)
@@ -1111,6 +1111,44 @@ def embedding_bag_offset_test():
     return ([index, offset, node], [weight], [y])
 
+@onnx_test
+def equal_test():
+    ax1 = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+    x1 = helper.make_tensor("x1",
+                            data_type=TensorProto.FLOAT,
+                            dims=(2, 3),
+                            vals=ax1.astype(np.float32))
+    x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 3])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3])
+
+    node = onnx.helper.make_node(
+        'Equal',
+        inputs=['x1', 'x2'],
+        outputs=['y'],
+    )
+
+    return ([node], [x2], [y], [x1])
+
+
+@onnx_test
+def equal_bool_test():
+    x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3])
+    x2 = helper.make_tensor_value_info('x2', TensorProto.BOOL, [2, 3])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3])
+
+    # to=9 is TensorProto.BOOL in the ONNX DataType enum
+    node1 = onnx.helper.make_node('Cast', inputs=['x1'], outputs=['bx1'], to=9)
+    node2 = onnx.helper.make_node(
+        'Equal',
+        inputs=['bx1', 'x2'],
+        outputs=['y'],
+    )
+
+    return ([node1, node2], [x1, x2], [y])
+
+
 @onnx_test
 def erf_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
......