Commit 8c1f0289 authored by Khalique's avatar Khalique
Browse files

Merge branch 'test_gen_scripts' of...

Merge branch 'test_gen_scripts' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into test_gen_scripts
parents 4295961c d41ae6ed
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
...@@ -82,6 +84,14 @@ void eliminate_contiguous::apply(program& p) const ...@@ -82,6 +84,14 @@ void eliminate_contiguous::apply(program& p) const
{ {
instruction::replace_argument(ins, arg, prev); instruction::replace_argument(ins, arg, prev);
} }
else if(prev->can_eval())
{
auto c = op::contiguous{};
auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto l = p.add_literal(r.get_shape(), r.data());
p.replace_instruction(arg, l);
}
} }
} }
} }
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ROUND_HPP
#define MIGRAPHX_GUARD_OPERATORS_ROUND_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct round : unary<round>
{
auto apply() const
{
return [](auto x) { return std::round(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -56,6 +56,7 @@ ...@@ -56,6 +56,7 @@
#include <migraphx/op/rnn.hpp> #include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp> #include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_output.hpp> #include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
......
...@@ -55,6 +55,7 @@ struct onnx_parser ...@@ -55,6 +55,7 @@ struct onnx_parser
add_generic_op("Acos", op::acos{}); add_generic_op("Acos", op::acos{});
add_generic_op("Atan", op::atan{}); add_generic_op("Atan", op::atan{});
add_generic_op("Sqrt", op::sqrt{}); add_generic_op("Sqrt", op::sqrt{});
add_generic_op("Round", op::round{});
add_generic_op("Sign", op::sign{}); add_generic_op("Sign", op::sign{});
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
...@@ -523,15 +524,7 @@ struct onnx_parser ...@@ -523,15 +524,7 @@ struct onnx_parser
if(contains(attributes, "ends")) if(contains(attributes, "ends"))
{ {
literal s = parse_value(attributes.at("ends")); op.ends = get_indices(attributes.at("ends"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
for(size_t i = 0; i < num_dims; i++)
{
if(static_cast<size_t>(op.ends[i]) > dims[i])
{
op.ends[i] = dims[i];
}
}
} }
if(contains(attributes, "starts")) if(contains(attributes, "starts"))
{ {
...@@ -1541,6 +1534,20 @@ struct onnx_parser ...@@ -1541,6 +1534,20 @@ struct onnx_parser
return result; return result;
} }
static std::vector<int64_t> get_indices(const onnx::AttributeProto& attr)
{
std::vector<int64_t> result;
literal s = parse_value(attr);
s.visit([&](auto v) { copy(v, std::back_inserter(result)); });
// Clamp large indices to -1
std::replace_if(
result.begin(),
result.end(),
[](auto x) { return x > int64_t{std::numeric_limits<std::int32_t>::max()} / 2; },
-1);
return result;
}
template <class T> template <class T>
static literal from_repeated(shape::type_t t, const T& r) static literal from_repeated(shape::type_t t, const T& r)
{ {
......
...@@ -46,6 +46,7 @@ add_library(migraphx_device ...@@ -46,6 +46,7 @@ add_library(migraphx_device
device/clip.cpp device/clip.cpp
device/reduce_sum.cpp device/reduce_sum.cpp
device/rsqrt.cpp device/rsqrt.cpp
device/round.cpp
device/sqrt.cpp device/sqrt.cpp
device/reduce_mean.cpp device/reduce_mean.cpp
device/pow.cpp device/pow.cpp
......
#include <migraphx/gpu/device/round.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void round(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) { return ::round(to_hip_type(x)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ROUND_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ROUND_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void round(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_ROUND_HPP
#define MIGRAPHX_GUARD_RTGLIB_ROUND_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/round.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_round : unary_device<hip_round, device::round>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -55,6 +55,7 @@ ...@@ -55,6 +55,7 @@
#include <migraphx/gpu/convert.hpp> #include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp> #include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp> #include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/round.hpp>
#include <migraphx/gpu/rsqrt.hpp> #include <migraphx/gpu/rsqrt.hpp>
#include <migraphx/gpu/sqrt.hpp> #include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_mean.hpp> #include <migraphx/gpu/reduce_mean.hpp>
...@@ -86,6 +87,7 @@ struct miopen_apply ...@@ -86,6 +87,7 @@ struct miopen_apply
void init() void init()
{ {
this->last = instruction::get_output_alias(std::prev(prog->end())); this->last = instruction::get_output_alias(std::prev(prog->end()));
add_miopen_simple_op<miopen_abs>("abs", make_abs); add_miopen_simple_op<miopen_abs>("abs", make_abs);
add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu); add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu);
...@@ -111,6 +113,7 @@ struct miopen_apply ...@@ -111,6 +113,7 @@ struct miopen_apply
add_generic_op<hip_max>("max"); add_generic_op<hip_max>("max");
add_generic_op<hip_min>("min"); add_generic_op<hip_min>("min");
add_generic_op<hip_rsqrt>("rsqrt"); add_generic_op<hip_rsqrt>("rsqrt");
add_generic_op<hip_round>("round");
add_generic_op<hip_pow>("pow"); add_generic_op<hip_pow>("pow");
add_generic_op<hip_sqdiff>("sqdiff"); add_generic_op<hip_sqdiff>("sqdiff");
add_generic_op<hip_relu>("relu"); add_generic_op<hip_relu>("relu");
......
...@@ -2029,6 +2029,25 @@ TEST_CASE(sqdiff_test) ...@@ -2029,6 +2029,25 @@ TEST_CASE(sqdiff_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(round_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {9}};
auto l = p.add_literal(migraphx::literal{s, {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}});
p.add_instruction(migraphx::op::round{}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
for(auto v : results_vector)
{
std::cout << v << "\t";
}
std::cout << std::endl;
std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(op_capture) TEST_CASE(op_capture)
{ {
migraphx::program p; migraphx::program p;
...@@ -2062,6 +2081,6 @@ TEST_CASE(op_capture) ...@@ -2062,6 +2081,6 @@ TEST_CASE(op_capture)
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); }); res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(vec, cap_vec)); EXPECT(migraphx::verify_range(vec, cap_vec));
}; }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -22,7 +22,7 @@ struct eliminate_contiguous_target ...@@ -22,7 +22,7 @@ struct eliminate_contiguous_target
TEST_CASE(standard_op) TEST_CASE(standard_op)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c); p.add_instruction(pass_standard_op{}, c);
...@@ -31,18 +31,40 @@ TEST_CASE(standard_op) ...@@ -31,18 +31,40 @@ TEST_CASE(standard_op)
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
} }
TEST_CASE(non_standard_op) TEST_CASE(standard_op_const)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c);
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == 2);
}
TEST_CASE(non_standard_op)
{
migraphx::program p;
auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c); p.add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{}); p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
} }
TEST_CASE(non_standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c);
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == 2);
}
TEST_CASE(transpose_gemm) TEST_CASE(transpose_gemm)
{ {
migraphx::program p; migraphx::program p;
...@@ -59,7 +81,7 @@ TEST_CASE(transpose_gemm) ...@@ -59,7 +81,7 @@ TEST_CASE(transpose_gemm)
TEST_CASE(transpose_standard_op) TEST_CASE(transpose_standard_op)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c); auto sn = p.add_instruction(migraphx::op::sin{}, c);
...@@ -69,6 +91,18 @@ TEST_CASE(transpose_standard_op) ...@@ -69,6 +91,18 @@ TEST_CASE(transpose_standard_op)
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
} }
TEST_CASE(transpose_standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
p.add_instruction(pass_standard_op{}, sn);
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == 3);
}
TEST_CASE(no_packed_unary_op) TEST_CASE(no_packed_unary_op)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -3846,6 +3846,18 @@ struct test_reduce_mean_half : verify_program<test_reduce_mean_half> ...@@ -3846,6 +3846,18 @@ struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
}; };
}; };
struct test_round : verify_program<test_round>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s);
p.add_instruction(migraphx::op::round{}, param);
return p;
}
};
struct test_convert : verify_program<test_convert> struct test_convert : verify_program<test_convert>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -1054,4 +1054,14 @@ TEST_CASE(unknown_test) ...@@ -1054,4 +1054,14 @@ TEST_CASE(unknown_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(round_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
p.add_instruction(migraphx::op::round{}, input);
auto prog = migraphx::parse_onnx("round_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
 round-example:E
xy"Round
test_roundZ
x
 

b
y
 

B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment