Commit 11e155c2 authored by Paul's avatar Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
......@@ -10,12 +9,18 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
add_dependencies(check ${NAME})
endfunction()
add_api_test(array_base test_array_base.cpp ${TEST_ONNX_DIR})
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
# GPU-based tests
if(MIGRAPHX_ENABLE_GPU)
add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
# GPU-based tests
target_link_libraries(test_api_gpu migraphx_gpu)
endif()
#include <migraphx/migraphx.hpp>
#include "test.hpp"
struct array2 : migraphx::array_base<array2>
{
std::vector<int> v;
array2() = default;
array2(std::initializer_list<int> x) : v(x) {}
std::size_t size() const { return v.size(); }
int operator[](std::size_t i) const { return v[i]; }
};
TEST_CASE(iterators)
{
array2 a = {1, 2, 3};
EXPECT(bool{std::equal(a.begin(), a.end(), a.v.begin())});
}
TEST_CASE(front_back)
{
array2 a = {1, 2, 3};
EXPECT(a.front() == 1);
EXPECT(a.back() == 3);
}
TEST_CASE(empty)
{
array2 a = {1, 2, 3};
EXPECT(not a.empty());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(shape_assign)
{
auto s1_cpp = migraphx::shape{migraphx_shape_float_type, {1, 3}};
std::vector<size_t> lens{2, 3};
// handle ptr is const, workaround to construct shape using C API
migraphx_shape_t s2;
migraphx_shape_create(&s2, migraphx_shape_float_type, lens.data(), lens.size());
auto s2_cpp = migraphx::shape(s2, migraphx::own{});
CHECK(bool{s1_cpp != s2_cpp});
// use C++ API for assignment
s1_cpp.assign_to_handle(s2);
CHECK(bool{s1_cpp == s2_cpp});
auto s3_cpp = migraphx::shape{migraphx_shape_float_type, lens};
// use C API for assignment
migraphx_shape_assign_to(s2, s3_cpp.get_handle_ptr());
CHECK(bool{s2_cpp == s3_cpp});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
struct simple_custom_op final : migraphx::experimental_custom_op_base
{
virtual std::string name() const override { return "simple_custom_op"; }
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{
return inputs.front();
}
};
TEST_CASE(register_custom_op)
{
simple_custom_op simple_op;
migraphx::register_experimental_custom_op(simple_op);
auto op = migraphx::operation("simple_custom_op");
EXPECT(op.name() == "simple_custom_op");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <numeric>
#include <hip/hip_runtime_api.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
......@@ -25,6 +26,24 @@ TEST_CASE(load_and_run)
CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
}
TEST_CASE(load_and_run_ctx)
{
auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
migraphx::compile_options options;
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
migraphx::program_parameters pp;
auto param_shapes = p.get_parameter_shapes();
for(auto&& name : param_shapes.names())
{
pp.add(name, migraphx::argument::generate(param_shapes[name]));
}
auto ctx = p.experimental_get_context();
EXPECT(ctx.get_queue<hipStream_t>() != nullptr);
p.eval(pp);
ctx.finish();
}
TEST_CASE(if_pl_test)
{
auto run_prog = [&](auto cond) {
......
#include <migraphx/migraphx.hpp>
#include "test.hpp"
template <class T>
std::false_type has_handle(migraphx::rank<0>, T)
{
return {};
}
template <class T>
auto has_handle(migraphx::rank<1>, T*) -> decltype(migraphx::as_handle<T>{}, std::true_type{})
{
return {};
}
TEST_CASE(shape)
{
static_assert(std::is_same<migraphx::as_handle<migraphx_shape>, migraphx::shape>{}, "Failed");
static_assert(std::is_same<migraphx::as_handle<migraphx_shape_t>, migraphx::shape>{}, "Failed");
static_assert(std::is_same<migraphx::as_handle<const_migraphx_shape_t>, migraphx::shape>{},
"Failed");
}
TEST_CASE(non_handle)
{
int i = 0;
EXPECT(bool{has_handle(migraphx::rank<1>{}, migraphx_shape_t{})});
EXPECT(bool{not has_handle(migraphx::rank<1>{}, &i)});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <numeric>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(add_literals)
{
migraphx::program p;
migraphx::module m = p.get_main_module();
migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
std::vector<float> x_values(9, 1);
auto x = m.add_literal(param_shape, x_values.data());
std::vector<float> y_values(9, -1);
auto y = m.add_literal(param_shape, y_values.data());
auto add_op = migraphx::operation("add");
auto r = m.add_instruction(add_op, {x, y});
m.add_return({r});
// run on ref target
p.compile(migraphx::target("ref"));
migraphx::program_parameters pp;
auto outputs = p.eval(pp);
auto output = outputs[0];
std::vector<float> expected(9, 0);
CHECK(bool(output == migraphx::argument(param_shape, expected.data())));
}
TEST_CASE(if_then_else_op)
{
migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
migraphx::shape cond_s{migraphx_shape_bool_type};
auto create_program = [&]() {
migraphx::program p;
auto mm = p.get_main_module();
auto cond = mm.add_parameter("cond", cond_s);
auto x = mm.add_parameter("x", param_shape);
auto y = mm.add_parameter("y", param_shape);
auto then_mod = p.create_module("If_0_if");
auto x_identity = then_mod.add_instruction(migraphx::operation("identity"), {x});
then_mod.add_return({x_identity});
auto else_mod = p.create_module("If_0_else");
auto y_identity = else_mod.add_instruction(migraphx::operation("identity"), {y});
else_mod.add_return({y_identity});
auto if_ins = mm.add_instruction(migraphx::operation("if"), {cond}, {then_mod, else_mod});
auto get_tuple_op = migraphx::operation("get_tuple_elem", "{index: 0}");
auto ret = mm.add_instruction(get_tuple_op, {if_ins});
mm.add_return({ret});
return p;
};
std::vector<float> x_data(9, 1);
std::vector<float> y_data(9, -1);
auto x_arg = migraphx::argument(param_shape, x_data.data());
auto y_arg = migraphx::argument(param_shape, y_data.data());
auto run_prog = [&](bool cond) {
auto p = create_program();
p.compile(migraphx::target("ref"));
auto outputs =
p.eval({{"cond", migraphx::argument(cond_s, &cond)}, {"x", x_arg}, {"y", y_arg}});
return outputs[0];
};
// then branch
auto then_res = run_prog(true);
CHECK(bool{then_res == x_arg});
// else branch
auto else_res = run_prog(false);
CHECK(bool{else_res == y_arg});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -4,7 +4,7 @@
TEST_CASE(load_save_default)
{
std::string filename = "migraphx_api_load_save.dat";
std::string filename = "migraphx_api_load_save.mxr";
auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
auto s1 = p1.get_output_shapes();
......
......@@ -135,4 +135,52 @@ TEST_CASE(two_transpose_gather)
EXPECT(m1 == m2);
}
TEST_CASE(standard_reshape)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m1.add_instruction(migraphx::make_op("add"), data, data);
auto r = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), add);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
TEST_CASE(dead_instruction)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), data);
auto r = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}),
data);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), data);
auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}),
data);
auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -180,6 +180,40 @@ TEST_CASE(duplicate_args3)
EXPECT(result == migraphx::literal{0});
}
TEST_CASE(reused_twice)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 2, 2};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, dims});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, dims});
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z);
auto epsilon = mm->add_literal(1e-12f);
auto exponent = mm->add_literal(2.0f);
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), add2);
auto mean_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto sub = mm->add_instruction(migraphx::make_op("sub"), add2, mean_mbcast);
auto exponent_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
auto pow = mm->add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
auto epsilon_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
auto add_epsilon = mm->add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
mm->add_instruction(migraphx::make_op("sqrt"), add_epsilon);
mm->add_instruction(migraphx::make_op("add"), x, y);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
p.debug_print();
EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 4);
}
TEST_CASE(unused_module)
{
migraphx::program p;
......
......@@ -6,6 +6,12 @@
#include <test.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(
p, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}});
}
void run_pass(migraphx::module& m)
{
migraphx::run_passes(
......@@ -142,4 +148,59 @@ TEST_CASE(cse_test_literal)
EXPECT(m1 == m2);
}
TEST_CASE(cse_test_submodule)
{
migraphx::shape si{migraphx::shape::int64_type};
migraphx::shape s{migraphx::shape::int64_type, {1}};
migraphx::shape sc{migraphx::shape::bool_type};
auto create_program = [&](bool remove_literal = false) {
migraphx::program p;
std::vector<bool> vc = {true};
std::vector<int64_t> vd = {3};
auto* mm = p.get_main_module();
auto in_cond = mm->add_parameter("ccond", sc);
auto in_val = mm->add_parameter("val", s);
auto b0 = mm->add_literal(migraphx::literal(sc, vc));
auto b1 = b0;
if(not(remove_literal))
b1 = mm->add_literal(migraphx::literal(sc, vc));
auto* body1 = p.create_module("loop_module1");
body1->add_parameter("#loop_module_in_1", sc);
auto in_v1 = body1->add_parameter("#loop_module_in_2", s);
auto l1 = body1->add_literal(migraphx::literal(si, vd));
auto ad1 = body1->add_instruction(migraphx::make_op("add"), l1, l1);
auto val1 = body1->add_instruction(migraphx::make_op("add"), in_v1, ad1);
auto cond1 = body1->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b0);
auto cond2 = body1->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b1);
body1->add_return({cond1, cond2, val1, val1});
auto* body2 = p.create_module("loop_module2");
body2->add_parameter("#loop_module_in_1", sc);
auto in_v2 = body2->add_parameter("#loop_module_in_2", s);
auto l2 = body2->add_literal(migraphx::literal(si, vd));
auto ad2 = body2->add_instruction(migraphx::make_op("add"), l2, l2);
auto val2 = body2->add_instruction(migraphx::make_op("add"), in_v2, ad2);
auto cond3 = body2->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b1);
body2->add_return({cond3, val2, val2});
auto loop1 = mm->add_instruction(
migraphx::make_op("loop", {{"max_iterations", 1}}), {in_cond, in_val}, {body1});
auto loop2 = mm->add_instruction(
migraphx::make_op("loop", {{"max_iterations", 1}}), {in_cond, in_val}, {body2});
mm->add_return({loop1, loop2});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_program(true));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -55,7 +55,8 @@ TEST_CASE(rewrite_pad)
auto l0 = create_im2col(padded_img, channels, m);
auto l1 = create_conv(padded_img, channels, m);
auto l2 = m.add_instruction(migraphx::make_op("pooling", {{"mode", "max"}}), padded_img);
auto l2 = m.add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), padded_img);
m.add_instruction(migraphx::make_op("identity"), l0, l1, l2);
auto s0 = l0->get_shape();
......
......@@ -5,7 +5,7 @@
#include <migraphx/context.hpp>
#include "test.hpp"
TEST_CASE(gpu_context)
TEST_CASE(gpu_context_serialize)
{
migraphx::context ctx = migraphx::gpu::context{0, 3};
......@@ -25,4 +25,10 @@ TEST_CASE(gpu_context)
EXPECT(v == v1);
}
TEST_CASE(context_queue)
{
migraphx::context ctx = migraphx::gpu::context{0, 3};
EXPECT(ctx.get_queue().get<hipStream_t>() != nullptr);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <test.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/tanh.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/instruction.hpp>
migraphx::program create_gelu()
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data0 = {0.044715};
std::vector<float> data1 = {0.797885};
std::vector<float> data2 = {3};
std::vector<float> data3 = {0.5};
migraphx::shape s0{migraphx::shape::float_type, {1}};
std::vector<size_t> x_dims{1, 1, 5};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, x_dims});
auto const_val = mm->add_literal(migraphx::literal{s0, data0});
auto sqrt_2_pi = mm->add_literal(migraphx::literal{s0, data1});
auto three_val = mm->add_literal(migraphx::literal{s0, data2});
auto half_val = mm->add_literal(migraphx::literal{s0, data3});
auto mbcast_3 = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, three_val);
auto pow_op = mm->add_instruction(migraphx::op::pow{}, x, mbcast_3);
auto mbcast_const = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, const_val);
auto mul_const = mm->add_instruction(migraphx::op::mul{}, mbcast_const, pow_op);
auto add_x = mm->add_instruction(migraphx::op::add{}, x, mul_const);
auto mbcast_sqrt_2_pi = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, sqrt_2_pi);
auto mul_add_x = mm->add_instruction(migraphx::op::mul{}, mbcast_sqrt_2_pi, add_x);
auto tanh_op = mm->add_instruction(migraphx::op::tanh{}, mul_add_x);
auto mbcast_half = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, half_val);
auto mul_half = mm->add_instruction(migraphx::op::mul{}, mbcast_half, tanh_op);
auto add_mul_half = mm->add_instruction(migraphx::op::add{}, mul_half, mbcast_half);
auto mul_x = mm->add_instruction(migraphx::op::mul{}, x, add_mul_half);
mm->add_return({mul_x});
return p;
}
TEST_CASE(enable_fast_gelu)
{
migraphx::program p = create_gelu();
p.compile(migraphx::gpu::target{});
CHECK(any_of(*p.get_main_module(), [&](auto&& i) { return i.name() == "gpu::gelu"; }));
}
TEST_CASE(disable_fast_gelu)
{
migraphx::program p = create_gelu();
migraphx::compile_options options;
options.fast_math = false;
p.compile(migraphx::gpu::target{}, options);
CHECK(any_of(*p.get_main_module(), [&](auto&& i) { return i.name() == "gpu::gelu_new"; }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -3,6 +3,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/program.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/gpu/kernel.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
......@@ -10,7 +11,7 @@
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_pointwise.hpp>
#include <migraphx/gpu/compiler.hpp>
// NOLINTNEXTLINE
const std::string write_2s = R"__migraphx__(
......@@ -109,6 +110,24 @@ int main() {}
)__migraphx__";
// NOLINTNEXTLINE
const std::string math_template = R"__migraphx__(
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/math.hpp>
extern "C" {
__global__ void kernel(${type}* p)
{
auto x = *p;
*p = migraphx::implicit_conversion(migraphx::${invoke});
}
}
int main() {}
)__migraphx__";
migraphx::src_file make_src_file(const std::string& name, const std::string& content)
{
return {name, std::make_pair(content.data(), content.data() + content.size())};
......@@ -230,7 +249,8 @@ TEST_CASE(compile_pointwise)
migraphx::shape input{migraphx::shape::float_type, {5, 2}};
migraphx::gpu::context ctx;
auto co = migraphx::gpu::compile_pointwise(ctx, {input, input}, "[](auto x) { return x + 1; }");
auto co = migraphx::gpu::compile_op(
"pointwise", ctx, {input, input}, {{"lambda", "[](auto x) { return x + 1; }"}});
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -247,4 +267,66 @@ TEST_CASE(compile_pointwise)
EXPECT(result == output_literal.get_argument());
}
TEST_CASE(compile_math)
{
std::vector<std::string> math_invoke = {
// clang-format off
"abs(x)",
"acos(x)",
"acosh(x)",
"asin(x)",
"asinh(x)",
"atan(x)",
"atanh(x)",
"ceil(x)",
"cos(x)",
"cosh(x)",
"erf(x)",
"exp(x)",
"floor(x)",
"isnan(x)",
"log(x)",
"max(x, x)",
"min(x, x)",
"pow(x, 0)",
"pow(x, x)",
"round(x)",
"rsqrt(x)",
"sin(x)",
"sinh(x)",
"sqrt(x)",
"tan(x)",
"tanh(x)",
"where(true, x, x)",
// clang-format on
};
std::vector<std::string> data_types;
auto vec_sizes = {2, 4, 6};
for(auto&& t : migraphx::shape::types())
{
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue;
auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::");
data_types.push_back(name);
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
});
}
migraphx::shape input{migraphx::shape::float_type, {5, 2}};
migraphx::gpu::hip_compile_options options;
options.global = 1024;
options.local = 1024;
options.inputs = {input};
options.output = input;
migraphx::par_for(math_invoke.size() * data_types.size(), 1, [&](auto i) {
const auto& t = data_types[i % data_types.size()];
const auto& invoke = math_invoke[i / data_types.size()];
auto src = migraphx::interpolate_string(math_template, {{"type", t}, {"invoke", invoke}});
auto co = migraphx::gpu::compile_hip_code_object(src, options);
(void)co;
});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -68,9 +68,9 @@ struct nop
{
static std::string as_string() { return ""; }
template <class T>
static decltype(auto) call(T&& x)
static auto call(T&& x)
{
return x;
return static_cast<T&&>(x);
}
};
......@@ -113,6 +113,33 @@ inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.
return s;
}
template <class T>
const T& get_value(const T& x)
{
return x;
}
template <class T, class Operator = nop>
struct lhs_expression;
template <class T>
lhs_expression<T> make_lhs_expression(T&& lhs);
template <class T, class Operator>
lhs_expression<T, Operator> make_lhs_expression(T&& lhs, Operator);
// NOLINTNEXTLINE
#define TEST_EXPR_BINARY_OPERATOR(op, name) \
template <class V> \
auto operator op(const V& rhs2) const \
{ \
return make_expression(*this, rhs2, name{}); /* NOLINT */ \
}
// NOLINTNEXTLINE
#define TEST_EXPR_UNARY_OPERATOR(op, name) \
auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ }
template <class T, class U, class Operator>
struct expression
{
......@@ -125,7 +152,12 @@ struct expression
return s;
}
decltype(auto) value() const { return Operator::call(lhs, rhs); };
friend decltype(auto) get_value(const expression& e) { return e.value(); }
decltype(auto) value() const { return Operator::call(get_value(lhs), get_value(rhs)); };
TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR)
TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR)
};
// TODO: Remove rvalue references
......@@ -135,9 +167,6 @@ expression<T, U, Operator> make_expression(T&& rhs, U&& lhs, Operator)
return {std::forward<T>(rhs), std::forward<U>(lhs)};
}
template <class T, class Operator = nop>
struct lhs_expression;
// TODO: Remove rvalue reference
template <class T>
lhs_expression<T> make_lhs_expression(T&& lhs)
......@@ -166,22 +195,12 @@ struct lhs_expression
return s;
}
decltype(auto) value() const { return Operator::call(lhs); }
// NOLINTNEXTLINE
#define TEST_LHS_BINARY_OPERATOR(op, name) \
template <class U> \
auto operator op(const U& rhs) const \
{ \
return make_expression(lhs, rhs, name{}); /* NOLINT */ \
}
friend decltype(auto) get_value(const lhs_expression& e) { return e.value(); }
TEST_FOREACH_BINARY_OPERATORS(TEST_LHS_BINARY_OPERATOR)
decltype(auto) value() const { return Operator::call(get_value(lhs)); }
// NOLINTNEXTLINE
#define TEST_LHS_UNARY_OPERATOR(op, name) \
auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ }
TEST_FOREACH_UNARY_OPERATORS(TEST_LHS_UNARY_OPERATOR)
TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR)
TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR)
// NOLINTNEXTLINE
#define TEST_LHS_REOPERATOR(op) \
......@@ -197,7 +216,7 @@ struct lhs_expression
TEST_LHS_REOPERATOR(%)
TEST_LHS_REOPERATOR(&)
TEST_LHS_REOPERATOR(|)
TEST_LHS_REOPERATOR (^)
TEST_LHS_REOPERATOR(^)
};
template <class F>
......@@ -223,6 +242,13 @@ auto make_predicate(const std::string& msg, F f)
return make_lhs_expression(predicate<F>{msg, f}, function{});
}
inline std::string as_string(bool x)
{
if(x)
return "true";
return "false";
}
template <class T>
std::string as_string(const T& x)
{
......@@ -627,18 +653,21 @@ inline void run(int argc, const char* argv[])
} // namespace test
// NOLINTNEXTLINE
#define TEST_CAPTURE(...) test::capture{}->*__VA_ARGS__
// NOLINTNEXTLINE
#define CHECK(...) \
test::failed( \
test::capture{}->*__VA_ARGS__, #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] { \
})
// NOLINTNEXTLINE
#define EXPECT(...) \
test::failed(test::capture{}->*__VA_ARGS__, \
#__VA_ARGS__, \
__PRETTY_FUNCTION__, \
__FILE__, \
__LINE__, \
#define EXPECT(...) \
test::failed(TEST_CAPTURE(__VA_ARGS__), \
#__VA_ARGS__, \
__PRETTY_FUNCTION__, \
__FILE__, \
__LINE__, \
&test::fail)
// NOLINTNEXTLINE
#define STATUS(...) EXPECT((__VA_ARGS__) == 0)
......
......@@ -55,7 +55,9 @@ TEST_CASE(rewrite_pad)
auto l0 = create_im2col(l_img, channels, m);
auto l1 = create_conv(l_img, channels, m);
auto l2 = m.add_instruction(
migraphx::make_op("pooling", {{"mode", "max"}, {"padding", {0, 0, 1, 1}}}), l_img);
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {"padding", {0, 0, 1, 1}}}),
l_img);
m.add_instruction(migraphx::make_op("identity"), l0, l1, l2);
run_pass(m);
......@@ -76,8 +78,10 @@ TEST_CASE(rewrite_pad_symmetric)
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = m.add_literal(migraphx::literal{s_img, input});
m.add_instruction(migraphx::make_op("pooling", {{"mode", "max"}, {"padding", {1, 1, 1, 1}}}),
l_img);
m.add_instruction(
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {"padding", {1, 1, 1, 1}}}),
l_img);
run_pass(m);
EXPECT(std::none_of(
......
......@@ -332,7 +332,7 @@ TEST_CASE(match_either_args_any1)
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any2)
......@@ -347,7 +347,7 @@ TEST_CASE(match_either_args_any2)
match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any3)
......@@ -362,7 +362,7 @@ TEST_CASE(match_either_args_any3)
match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any4)
......@@ -377,7 +377,7 @@ TEST_CASE(match_either_args_any4)
match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any5)
......@@ -392,7 +392,7 @@ TEST_CASE(match_either_args_any5)
match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_all_of1)
......@@ -747,10 +747,10 @@ TEST_CASE(match_bind1)
match::standard_shape())
.bind("pass");
auto r = find_match(mm, m);
EXPECT(bool{r.instructions.at("one") == one});
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.instructions["one"] == one});
EXPECT(bool{r.instructions["two"] == two});
EXPECT(bool{r.instructions["sum"] == sum});
EXPECT(bool{r.instructions["pass"] == pass});
EXPECT(bool{r.result == pass});
}
......@@ -795,9 +795,9 @@ TEST_CASE(match_bind_modules2)
match::standard_shape())
.bind("pass");
auto r = find_match(*child, m);
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.instructions["two"] == two});
EXPECT(bool{r.instructions["sum"] == sum});
EXPECT(bool{r.instructions["pass"] == pass});
EXPECT(bool{r.result == pass});
}
......
celu_alpha_test:R

xy"Celu*
alphaL?celu_alpha_testZ
x

b
y

B
\ No newline at end of file
celu_default_test:K
xy"Celucelu_default_testZ
x


b
y


B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment