Commit 7e931a37 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from round_operator

parents 1b4216ca dc2e2cf3
#ifndef MIGRAPHX_GUARD_OPERATORS_ROUND_HPP
#define MIGRAPHX_GUARD_OPERATORS_ROUND_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct round : unary<round>
{
auto apply() const
{
return [](auto x) { return std::round(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -56,6 +56,7 @@
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/sigmoid.hpp>
......
......@@ -55,6 +55,7 @@ struct onnx_parser
add_generic_op("Acos", op::acos{});
add_generic_op("Atan", op::atan{});
add_generic_op("Sqrt", op::sqrt{});
add_generic_op("Round", op::round{});
add_binary_op("Add", op::add{});
add_binary_op("Div", op::div{});
......
......@@ -198,14 +198,12 @@ void quantize_int8(program& prog,
}
else
{
quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second);
quant_input = insert_quant_ins(prog, input, quant_type, map_quant_ins);
}
}
else
{
quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second);
quant_input = insert_quant_ins(prog, input, quant_type, map_quant_ins);
}
converted_inputs.push_back(quant_input);
}
......
......@@ -42,6 +42,7 @@ add_library(migraphx_device
device/clip.cpp
device/reduce_sum.cpp
device/rsqrt.cpp
device/round.cpp
device/sqrt.cpp
device/reduce_mean.cpp
device/pow.cpp
......
#include <migraphx/gpu/device/round.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void round(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) { return ::round(to_hip_type(x)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ROUND_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ROUND_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void round(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_ROUND_HPP
#define MIGRAPHX_GUARD_RTGLIB_ROUND_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/round.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_round : unary_device<hip_round, device::round>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -54,6 +54,7 @@
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/round.hpp>
#include <migraphx/gpu/rsqrt.hpp>
#include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_mean.hpp>
......@@ -112,6 +113,7 @@ struct miopen_apply
add_generic_op<hip_max>("max");
add_generic_op<hip_min>("min");
add_generic_op<hip_rsqrt>("rsqrt");
add_generic_op<hip_round>("round");
add_generic_op<hip_pow>("pow");
add_generic_op<hip_sqdiff>("sqdiff");
......
......@@ -2047,6 +2047,25 @@ TEST_CASE(op_capture)
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(vec, cap_vec));
};
}
TEST_CASE(round_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {9}};
auto l = p.add_literal(migraphx::literal{s, {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}});
p.add_instruction(migraphx::op::round{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
for(auto v : results_vector)
{
std::cout << v << "\t";
}
std::cout << std::endl;
std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -3822,4 +3822,17 @@ struct test_convert : verify_program<test_convert>
};
};
struct test_round : verify_program<test_round>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
p.add_instruction(migraphx::op::round{}, param);
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -1012,4 +1012,14 @@ TEST_CASE(expand_test)
EXPECT(p == prog);
}
TEST_CASE(round_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
p.add_instruction(migraphx::op::round{}, input);
auto prog = migraphx::parse_onnx("round_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
 round-example:E
xy"Round
test_roundZ
x
 

b
y
 

B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment