"megatron/vscode:/vscode.git/clone" did not exist on "a71692976178dade8ab42686191150be4a7a3d7b"
Commit 00d5d880 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 00d90ca8 f60c3815
......@@ -57,20 +57,7 @@ struct parse_pooling : op_parser<parse_pooling>
calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]);
calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]);
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = info.add_instruction(
migraphx::make_op(
"pad",
{{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
l0);
}
else
{
op.padding[0] = pads[0];
op.padding[1] = pads[1];
}
op.padding = std::vector<size_t>(pads.begin(), pads.end());
}
}
return info.add_instruction(op, l0);
......
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
......@@ -113,15 +114,16 @@ TEST_CASE(depth_test)
TEST_CASE(undefined_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto undef = mm->add_instruction(migraphx::make_op("undefined"));
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
EXPECT(not mm->has_instruction(undef));
EXPECT(
std::none_of(mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "undefined"; }));
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
......
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
......@@ -10,7 +11,9 @@
void run_pass(migraphx::module& m)
{
migraphx::run_passes(m, {migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(
m,
{migraphx::normalize_ops{}, migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}});
}
migraphx::instruction_ref
......@@ -66,15 +69,15 @@ TEST_CASE(rewrite_pad)
auto om1 = l1->get_operator().to_value();
auto om2 = l2->get_operator().to_value();
EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(om1["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(om2["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1, 1, 1});
EXPECT(om1["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1, 1, 1});
EXPECT(om2["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1, 1, 1});
EXPECT(std::none_of(
m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
TEST_CASE(rewrite_pad_im2col_asymetric)
TEST_CASE(rewrite_pad_im2col_asymmetric)
{
migraphx::module m;
......@@ -95,10 +98,10 @@ TEST_CASE(rewrite_pad_im2col_asymetric)
EXPECT(l0->get_shape() == s0);
auto op0 = l0->get_operator().to_value();
EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{0, 0});
EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{0, 0, 2, 2});
run_pass(m);
EXPECT(std::any_of(
EXPECT(std::none_of(
m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
......
#include "migraphx/instruction_ref.hpp"
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/adjust_allocation.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_passes(migraphx::module& m)
{
auto ctx = migraphx::gpu::context{};
migraphx::run_passes(m,
{migraphx::auto_contiguous{},
migraphx::gpu::lowering{&ctx, false},
migraphx::dead_code_elimination{},
migraphx::gpu::pack_int8_args{},
migraphx::dead_code_elimination{}});
}
bool get_int8_x4_format()
{
bool int8_x4_format = true;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
auto ctx = migraphx::gpu::context{};
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
return int8_x4_format;
}
TEST_CASE(quant_dot)
{
auto create_module = [] {
migraphx::module m("test");
migraphx::shape m1_shape{migraphx::shape::int8_type, {5, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", m1_shape);
auto l2 = m.add_parameter("b", m2_shape);
auto l3 = m.add_parameter("c", m3_shape);
auto r = m.add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3);
m.add_return({r});
return m;
};
auto create_optimized_int8_x4 = [](bool int8_x4) {
migraphx::module m("test");
migraphx::shape m1_shape{migraphx::shape::int8_type, {5, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", m1_shape);
auto l2 = m.add_parameter("b", m2_shape);
auto l3 = m.add_parameter("c", m3_shape);
auto output = m.add_parameter("test:#output_0", m3_shape);
auto cout = m.add_instruction(migraphx::make_op("hip::copy"), l3, output);
auto packa = l2;
if(int8_x4)
{
auto alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m2_shape)}}));
packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), l2, alloc);
}
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm",
{{"alpha", 1}, {"beta", 1}, {"int8_x4_format", int8_x4}}),
l1,
packa,
cout,
cout);
m.add_return({gemm});
return m;
};
auto m1 = create_module();
run_passes(m1);
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
EXPECT(m1 == m2);
}
TEST_CASE(quant_dot_trans)
{
auto create_module = [] {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 8, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}};
auto l1 = m.add_parameter("a", s1);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto l2 = m.add_parameter("b", s2);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
m.add_return({r});
return m;
};
auto create_optimized_int8_x4 = [](bool int8_x4) {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 8, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape s3{migraphx::shape::int32_type, {3, 2, 5, 7}};
auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2);
auto output = m.add_parameter("test:#output_0", s3);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 8}};
auto alloca = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, alloca);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 8, 7}};
auto allocb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, allocb);
auto packb = contb;
if(int8_x4)
{
auto allocpb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), contb, allocpb);
}
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm",
{{"alpha", 3}, {"beta", 0}, {"int8_x4_format", int8_x4}}),
conta,
packb,
output);
m.add_return({gemm});
return m;
};
auto m1 = create_module();
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
run_passes(m1);
EXPECT(m1 == m2);
}
TEST_CASE(quant_dot_pad)
{
auto create_module = [] {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {5, 6}};
migraphx::shape s2{migraphx::shape::int8_type, {6, 7}};
migraphx::shape s3{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2);
auto l3 = m.add_parameter("c", s3);
auto r = m.add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3);
m.add_return({r});
return m;
};
auto create_optimized_int8_x4 = [](bool int8_x4) {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {5, 6}};
migraphx::shape ps1{migraphx::shape::int8_type, {5, 8}};
migraphx::shape s2{migraphx::shape::int8_type, {6, 7}};
migraphx::shape ps2{migraphx::shape::int8_type, {8, 7}};
migraphx::shape s3{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2);
auto l3 = m.add_parameter("c", s3);
auto output = m.add_parameter("test:#output_0", s3);
auto pl1 = l1;
auto packa = l2;
migraphx::instruction_ref pl2{};
if(int8_x4)
{
auto po1 = m.insert_instruction(
l1, migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}}));
pl1 = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 2, 0, 0}}, {"value", 0}}),
l1,
po1);
auto po2 = m.insert_instruction(
l2, migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
pl2 = m.insert_instruction(
std::next(l2),
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {2, 0, 0, 0}}, {"value", 0}}),
l2,
po2);
}
auto cout = m.add_instruction(migraphx::make_op("hip::copy"), l3, output);
if(int8_x4)
{
auto alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pl2, alloc);
}
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm",
{{"alpha", 1}, {"beta", 1}, {"int8_x4_format", int8_x4}}),
pl1,
packa,
cout,
cout);
m.add_return({gemm});
return m;
};
auto m1 = create_module();
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
run_passes(m1);
EXPECT(m1 == m2);
}
TEST_CASE(quant_dot_trans_pad)
{
auto create_module = [] {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 9, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}};
auto l1 = m.add_parameter("a", s1);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto l2 = m.add_parameter("b", s2);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
m.add_return({r});
return m;
};
auto create_optimized_int8_x4 = [](bool int8_x4) {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 9, 5}};
migraphx::shape ps1{migraphx::shape::int8_type, {3, 2, 5, 12}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}};
migraphx::shape ps2{migraphx::shape::int8_type, {3, 2, 12, 7}};
migraphx::shape s3{migraphx::shape::int32_type, {3, 2, 5, 7}};
auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2);
auto output = m.add_parameter("test:#output_0", s3);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}};
auto ta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
migraphx::instruction_ref pta{};
if(int8_x4)
{
pta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}}));
}
auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, ta);
auto pa = conta;
if(int8_x4)
{
pa = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 0, 3, 0, 0, 0, 0}}}),
conta,
pta);
}
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}};
auto tb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
migraphx::instruction_ref ptb{};
if(int8_x4)
{
ptb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
}
auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, tb);
auto packb = contb;
if(int8_x4)
{
auto pb = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 3, 0, 0, 0, 0, 0}}}),
contb,
ptb);
auto allocpb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pb, allocpb);
}
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm",
{{"alpha", 3}, {"beta", 0}, {"int8_x4_format", int8_x4}}),
pa,
packb,
output);
m.add_return({gemm});
return m;
};
auto m1 = create_module();
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
run_passes(m1);
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -63,9 +63,9 @@ TEST_CASE(int8_quantization)
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sa{migraphx::shape::float_type, {5, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
......@@ -77,9 +77,9 @@ TEST_CASE(int8_quantization)
{
auto p = create_program();
migraphx::parameter_map m;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sa{migraphx::shape::float_type, {5, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
m["a"] = migraphx::generate_argument(sa);
m["b"] = migraphx::generate_argument(sb);
m["c"] = migraphx::generate_argument(sc);
......
......@@ -9,6 +9,10 @@
#include <unordered_map>
#include <vector>
#ifdef __linux__
#include <unistd.h>
#endif
#ifndef MIGRAPHX_GUARD_TEST_TEST_HPP
#define MIGRAPHX_GUARD_TEST_TEST_HPP
......@@ -264,6 +268,32 @@ struct capture
}
};
enum class color
{
reset = 0,
bold = 1,
underlined = 4,
fg_red = 31,
fg_green = 32,
fg_yellow = 33,
fg_blue = 34,
fg_default = 39,
bg_red = 41,
bg_green = 42,
bg_yellow = 43,
bg_blue = 44,
bg_default = 49
};
inline std::ostream& operator<<(std::ostream& os, const color& c)
{
#ifndef _WIN32
static const bool use_color = isatty(STDOUT_FILENO) != 0;
if(use_color)
return os << "\033[" << static_cast<std::size_t>(c) << "m";
#endif
return os;
}
template <class T, class F>
void failed(T x, const char* msg, const char* func, const char* file, int line, F f)
{
......@@ -271,7 +301,7 @@ void failed(T x, const char* msg, const char* func, const char* file, int line,
{
std::cout << func << std::endl;
std::cout << file << ":" << line << ":" << std::endl;
std::cout << " FAILED: " << msg << " "
std::cout << color::bold << color::fg_red << " FAILED: " << color::reset << msg << " "
<< "[ " << x << " ]" << std::endl;
f();
}
......@@ -315,7 +345,7 @@ auto near(T px, U py, double ptol = 1e-6f)
using string_map = std::unordered_map<std::string, std::vector<std::string>>;
template <class Keyword>
string_map parse(std::vector<std::string> as, Keyword keyword)
string_map generic_parse(std::vector<std::string> as, Keyword keyword)
{
string_map result;
......@@ -331,19 +361,22 @@ string_map parse(std::vector<std::string> as, Keyword keyword)
{
flag = f.front();
result[flag]; // Ensure the flag exists
flag = f.back();
}
}
return result;
}
using test_case = std::function<void()>;
inline auto& get_test_cases()
{
// NOLINTNEXTLINE
static std::vector<std::pair<std::string, std::function<void()>>> cases;
static std::vector<std::pair<std::string, test_case>> cases;
return cases;
}
inline void add_test_case(std::string name, std::function<void()> f)
inline void add_test_case(std::string name, test_case f)
{
get_test_cases().emplace_back(std::move(name), std::move(f));
}
......@@ -357,37 +390,243 @@ struct auto_register_test_case
}
};
inline void run_test_case(const std::string& name, const std::function<void()>& f)
struct failure_error
{
std::cout << "[ RUN ] " << name << std::endl;
f();
std::cout << "[ COMPLETE ] " << name << std::endl;
}
};
inline void run(int argc, const char* argv[])
[[noreturn]] inline void fail() { throw failure_error{}; }
struct driver
{
std::vector<std::string> as(argv + 1, argv + argc);
driver()
{
add_flag({"--help", "-h"}, "Show help");
add_flag({"--list", "-l"}, "List all test cases");
add_flag({"--continue", "-c"}, "Continue after failure");
add_flag({"--quiet", "-q"}, "Don't print out extra output");
}
struct argument
{
std::vector<std::string> flags = {};
std::string help = "";
int nargs = 1;
};
auto args = parse(as, [](auto &&) -> std::vector<std::string> { return {}; });
auto cases = args[""];
if(cases.empty())
void add_arg(const std::vector<std::string>& flags, const std::string& help = "")
{
for(auto&& tc : get_test_cases())
run_test_case(tc.first, tc.second);
arguments.push_back(argument{flags, help, 1});
}
else
void add_flag(const std::vector<std::string>& flags, const std::string& help = "")
{
std::unordered_map<std::string, std::function<void()>> m(get_test_cases().begin(),
get_test_cases().end());
for(auto&& name : cases)
arguments.push_back(argument{flags, help, 0});
}
void show_help(const std::string& exe) const
{
std::cout << std::endl;
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " ";
std::cout << exe << " <test-case>... <options>" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "ARGS:" << color::reset << std::endl;
std::cout << " ";
std::cout << color::fg_green << "<test-case>..." << color::reset;
std::cout << std::endl;
std::cout << " "
<< "Test case name to run" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl;
for(auto&& arg : arguments)
{
std::string prefix = " ";
std::cout << color::fg_green;
for(const std::string& a : arg.flags)
{
std::cout << prefix;
std::cout << a;
prefix = ", ";
}
std::cout << color::reset << std::endl;
std::cout << " " << arg.help << std::endl;
}
}
std::ostream& out() const
{
struct null_buffer : std::streambuf
{
virtual int overflow(int c) override { return c; }
};
static null_buffer buffer;
static std::ostream null_stream(&buffer);
if(quiet)
return null_stream;
return std::cout;
}
string_map parse(int argc, const char* argv[]) const
{
std::vector<std::string> args(argv + 1, argv + argc);
string_map keys;
for(auto&& arg : arguments)
{
auto f = m.find(name);
if(f == m.end())
std::cout << "[ ERROR ] Test case '" << name << "' not found." << std::endl;
for(auto&& flag : arg.flags)
{
keys[flag] = {arg.flags.front()};
if(arg.nargs == 0)
keys[flag].push_back("");
}
}
auto result = generic_parse(args, [&](auto&& s) -> std::vector<std::string> {
if(keys.count(s) > 0)
return keys[s];
else
run_test_case(name, f->second);
return {};
});
result["__exe__"].push_back(argv[0]);
return result;
}
static std::string create_command(const string_map& args)
{
std::stringstream ss;
ss << args.at("__exe__").front();
if(args.count("") > 0)
{
for(auto&& arg : args.at(""))
ss << " \"" << arg << "\"";
}
for(auto&& p : args)
{
if(p.first == "__exe__")
continue;
if(p.first.empty())
continue;
ss << " " << p.first;
for(auto&& arg : p.second)
ss << " \"" << arg << "\"";
}
return ss.str();
}
static std::string fork(const std::string& name, string_map args)
{
std::string msg;
args[""] = {name};
args.erase("--continue");
args["--quiet"];
auto cmd = create_command(args);
auto r = std::system(cmd.c_str()); // NOLINT
if(r != 0)
msg = "Exited with " + std::to_string(r);
return msg;
}
void run_test_case(const std::string& name, const test_case& f, const string_map& args)
{
ran++;
out() << color::fg_green << "[ RUN ] " << color::reset << color::bold << name
<< color::reset << std::endl;
std::string msg;
if(args.count("--continue") > 0)
{
msg = fork(name, args);
}
else
{
try
{
f();
}
catch(const failure_error&)
{
msg = "Test failure";
}
}
if(msg.empty())
{
out() << color::fg_green << "[ COMPLETE ] " << color::reset << color::bold << name
<< color::reset << std::endl;
}
else
{
failed.push_back(name);
out() << color::fg_red << "[ FAILED ] " << color::reset << color::bold << name
<< color::reset << ": " << color::fg_yellow << msg << color::reset << std::endl;
}
}
void run(int argc, const char* argv[])
{
auto args = parse(argc, argv);
if(args.count("--help") > 0)
{
show_help(args.at("__exe__").front());
return;
}
if(args.count("--list") > 0)
{
for(auto&& tc : get_test_cases())
out() << tc.first << std::endl;
return;
}
if(args.count("--quiet") > 0)
quiet = true;
auto cases = args[""];
if(cases.empty())
{
for(auto&& tc : get_test_cases())
run_test_case(tc.first, tc.second, args);
}
else
{
std::unordered_map<std::string, test_case> m(get_test_cases().begin(),
get_test_cases().end());
for(auto&& iname : cases)
{
for(auto&& name : get_case_names(iname))
{
auto f = m.find(name);
if(f == m.end())
{
out() << color::fg_red << "[ ERROR ] Test case '" << name
<< "' not found." << color::reset << std::endl;
failed.push_back(name);
}
else
run_test_case(name, f->second, args);
}
}
}
out() << color::fg_green << "[==========] " << color::fg_yellow << ran << " tests ran"
<< color::reset << std::endl;
if(not failed.empty())
{
out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << failed.size()
<< " tests failed" << color::reset << std::endl;
for(auto&& name : failed)
out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << name
<< color::reset << std::endl;
std::exit(1);
}
}
std::function<std::vector<std::string>(const std::string&)> get_case_names =
[](const std::string& name) -> std::vector<std::string> { return {name}; };
std::vector<argument> arguments = {};
std::vector<std::string> failed = {};
std::size_t ran = 0;
bool quiet = false;
};
inline void run(int argc, const char* argv[])
{
driver d{};
d.run(argc, argv);
}
} // namespace test
......@@ -404,7 +643,7 @@ inline void run(int argc, const char* argv[])
__PRETTY_FUNCTION__, \
__FILE__, \
__LINE__, \
&std::abort)
&test::fail)
// NOLINTNEXTLINE
#define STATUS(...) EXPECT((__VA_ARGS__) == 0)
......
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::inline_module{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(cannot_inline_both)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3}};
auto x = mm->add_parameter("x", sd);
std::vector<float> one(sd.elements(), 1);
std::vector<float> two(sd.elements(), 2);
auto* then_smod = p.create_module("then_smod");
auto l1 = then_smod->add_literal(migraphx::literal{sd, one});
auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1);
then_smod->add_return({r1});
auto* else_smod = p.create_module("else_smod");
auto l2 = else_smod->add_literal(migraphx::literal{sd, two});
auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2);
else_smod->add_return({r2});
migraphx::shape s_cond{migraphx::shape::bool_type, {1}};
auto cond = mm->add_parameter("cond", s_cond);
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod});
mm->add_return({ret});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_program());
}
TEST_CASE(cannot_inline_one)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape s{migraphx::shape::float_type, {5}};
auto cond = mm->add_parameter("cond", cond_s);
auto x = mm->add_parameter("x", s);
auto* then_mod = p.create_module("If_0_if");
std::vector<float> data1 = {1, 2, 3, 4, 5};
auto l1 = then_mod->add_literal(migraphx::literal(s, data1));
then_mod->add_return({l1, x});
auto* else_mod = p.create_module("If_0_else");
std::vector<float> data2 = {5, 4, 3, 2, 1};
auto l2 = else_mod->add_literal(migraphx::literal(s, data2));
auto s2 = else_mod->add_instruction(migraphx::make_op("add"), x, l2);
else_mod->add_return({s2, l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_program());
}
TEST_CASE(inline_if_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto sm = mm->add_instruction(migraphx::make_op("add"), l1, x);
auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, sm);
then_mod->add_outline(s);
then_mod->add_return({rt});
auto* else_mod = p.create_module("If_5_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({re});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
};
auto create_inline = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto sm = mm->add_instruction(migraphx::make_op("add"), l1, x);
mm->add_parameter("y", s);
auto r = mm->add_instruction(migraphx::make_op("add"), x, sm);
mm->add_return({r});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_inline());
}
TEST_CASE(inline_else_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
then_mod->add_return({rt});
auto* else_mod = p.create_module("If_5_else");
else_mod->add_parameter("e", s);
else_mod->add_literal(migraphx::literal(s, ones));
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({re});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
};
auto create_inline = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
auto l2 = mm->add_literal(s, rand);
mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto r = mm->add_instruction(migraphx::make_op("mul"), y, l2);
mm->add_return({r});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_inline());
}
TEST_CASE(if_recursive_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
auto lx = mm->add_literal(migraphx::literal(xs, datax));
auto ly = mm->add_literal(migraphx::literal(ys, datay));
auto cond = mm->add_literal(migraphx::literal(cond_s, {0}));
auto x1 = mm->add_parameter("x1", xs);
auto x2 = mm->add_parameter("x2", xs);
auto y2 = mm->add_parameter("y2", ys);
auto cond1 = mm->add_parameter("cond", cond_s);
auto* then_mod = p.create_module("If_5_if");
auto l1 = then_mod->add_literal(migraphx::literal(ys, datay));
auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
then_mod->add_return({a1, l1});
auto* then_mod1 = p.create_module("If_6_if");
auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay));
auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
then_mod1->add_return({a11, l11});
auto* else_mod1 = p.create_module("If_6_else");
auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax));
auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
else_mod1->add_return({l21, a21});
auto* else_mod = p.create_module("If_5_else");
auto l2 = else_mod->add_literal(migraphx::literal(xs, datax));
auto a2 =
else_mod->add_instruction(migraphx::make_op("if"), {cond1}, {then_mod1, else_mod1});
auto a3 =
else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), a2);
else_mod->add_return({l2, a3});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r});
return p;
};
auto create_inline = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
auto lx = mm->add_literal(migraphx::literal(xs, datax));
auto ly = mm->add_literal(migraphx::literal(ys, datay));
mm->add_parameter("x1", xs);
auto x2 = mm->add_parameter("x2", xs);
auto y2 = mm->add_parameter("y2", ys);
auto cond1 = mm->add_parameter("cond", cond_s);
auto* then_mod1 = p.create_module("If_6_if");
auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay));
auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
then_mod1->add_return({a11, l11});
auto* else_mod1 = p.create_module("If_6_else");
auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax));
auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
else_mod1->add_return({l21, a21});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond1}, {then_mod1, else_mod1});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_inline());
}
TEST_CASE(if_recursive_cond0_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
auto lx = mm->add_literal(migraphx::literal(xs, datax));
auto ly = mm->add_literal(migraphx::literal(ys, datay));
auto cond = mm->add_literal(migraphx::literal(cond_s, {0}));
auto x1 = mm->add_parameter("x1", xs);
auto x2 = mm->add_parameter("x2", xs);
auto y2 = mm->add_parameter("y2", ys);
auto* then_mod = p.create_module("If_5_if");
auto l1 = then_mod->add_literal(migraphx::literal(ys, datay));
auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
then_mod->add_return({a1, l1});
auto* then_mod1 = p.create_module("If_6_if");
auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay));
auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
then_mod1->add_return({a11, l11});
auto* else_mod1 = p.create_module("If_6_else");
auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax));
auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
else_mod1->add_return({l21, a21});
auto* else_mod = p.create_module("If_5_else");
auto l2 = else_mod->add_literal(migraphx::literal(xs, datax));
auto a2 =
else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
auto a3 =
else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), a2);
else_mod->add_return({l2, a3});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r});
return p;
};
auto create_inline = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
mm->add_literal(migraphx::literal(xs, datax));
auto ly = mm->add_literal(migraphx::literal(ys, datay));
mm->add_parameter("x1", xs);
mm->add_parameter("x2", xs);
auto y2 = mm->add_parameter("y2", ys);
auto m = mm->add_instruction(migraphx::make_op("mul"), y2, ly);
mm->add_return({m});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_inline());
}
TEST_CASE(inline_tuple_true_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape sd{migraphx::shape::float_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r0, r1});
return p;
};
auto create_inline = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
auto add = mm->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
auto mul = mm->add_instruction(migraphx::make_op("mul"), y, m2);
mm->add_return({add, mul});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_inline());
}
TEST_CASE(inline_tuple_false_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape sd{migraphx::shape::float_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r0, r1});
return p;
};
auto create_inline = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
migraphx::shape sd{migraphx::shape::float_type, {1}};
mm->add_literal(migraphx::literal(sd, {1}));
mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
auto mul = mm->add_instruction(migraphx::make_op("mul"), x, m1);
auto m2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
auto add = mm->add_instruction(migraphx::make_op("add"), y, m2);
mm->add_return({mul, add});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_inline());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m)
{
migraphx::run_passes(
m, {migraphx::normalize_ops{}, migraphx::insert_pad{}, migraphx::dead_code_elimination{}});
}
migraphx::instruction_ref
create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::module& m)
{
size_t f[2] = {1, 1};
std::vector<int32_t> weights(channels * f[0] * f[1]);
migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_weights = m.add_literal(migraphx::literal{s_weights, weights});
return m.add_instruction(
migraphx::make_op("im2col", {{"padding", {0, 0, 1, 1}}}), l_img, l_weights);
}
migraphx::instruction_ref
create_conv(migraphx::instruction_ref& l_img,
size_t channels,
migraphx::module& m,
migraphx::op::padding_mode_t padding_mode = migraphx::op::padding_mode_t::default_)
{
migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
std::vector<int32_t> weights(4 * channels * 3 * 3);
auto l_weights = m.add_literal(migraphx::literal{s_weights, weights});
migraphx::op::convolution op;
op.padding_mode = padding_mode;
op.padding = {0, 0, 1, 1};
return m.add_instruction(op, l_img, l_weights);
}
TEST_CASE(rewrite_pad)
{
migraphx::module m;
size_t img_dim[2] = {2, 2};
size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = m.add_literal(migraphx::literal{s_img, input});
auto l0 = create_im2col(l_img, channels, m);
auto l1 = create_conv(l_img, channels, m);
auto l2 = m.add_instruction(
migraphx::make_op("pooling", {{"mode", "max"}, {"padding", {0, 0, 1, 1}}}), l_img);
m.add_instruction(migraphx::make_op("identity"), l0, l1, l2);
run_pass(m);
EXPECT(std::any_of(
m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
TEST_CASE(rewrite_pad_symmetric)
{
migraphx::module m;
size_t img_dim[2] = {2, 2};
size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = m.add_literal(migraphx::literal{s_img, input});
m.add_instruction(migraphx::make_op("pooling", {{"mode", "max"}, {"padding", {1, 1, 1, 1}}}),
l_img);
run_pass(m);
EXPECT(std::none_of(
m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include "test.hpp"
int main() {}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -22,576 +22,498 @@ migraphx::match::matcher_result find_match(migraphx::module& modl, M&& m)
void match1()
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(1);
auto m = match::standard_shape();
auto r = find_match(*mm, m);
migraphx::module mm;
auto l = mm.add_literal(1);
auto m = match::standard_shape();
auto r = find_match(mm, m);
EXPECT(bool{r.result == l});
}
TEST_CASE(match_name1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum");
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_name2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("min");
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_name3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_arg1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("@literal")), match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_arg2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_arg3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(1)(match::name("@literal")), match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_arg4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto pass = mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
auto pass = mm.add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::arg(0)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == pass});
}
TEST_CASE(match_arg5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_arg6)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("@literal")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_arg7)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_arg8)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))),
match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2), match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::nargs(2)));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_args1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("@literal"), match::name("@literal")),
match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_args2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")),
match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_args3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_args4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_args5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_args6)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto pass = mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
auto pass = mm.add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::args(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == pass});
}
TEST_CASE(match_args7)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto pass = mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
auto pass = mm.add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::args(match::name("sum")(match::args(
match::name("@literal"), match::name("@literal")))),
match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == pass});
}
TEST_CASE(match_either_args1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("sum"), match::name("@literal")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_either_args2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("@literal"), match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_either_args3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("pass"), match::name("@literal")));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_either_args_any1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_all_of1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_all_of2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_all_of3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::all_of(
match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal")))));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_lazy_any_of)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
mm->add_instruction(pass_op{}, one);
migraphx::module mm;
auto one = mm.add_literal(1);
mm.add_instruction(pass_op{}, one);
auto m = match::any_of(match::any(), throws());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == one});
}
TEST_CASE(match_lazy_all_of)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
mm->add_instruction(pass_op{}, one);
migraphx::module mm;
auto one = mm.add_literal(1);
mm.add_instruction(pass_op{}, one);
auto m = match::all_of(match::none(), throws());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_lazy_none_of)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
mm->add_instruction(pass_op{}, one);
migraphx::module mm;
auto one = mm.add_literal(1);
mm.add_instruction(pass_op{}, one);
auto m = match::none_of(match::any(), throws());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_any_of1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_any_of2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_any_of_lazy1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"),
match::args(match::name("sum"), match::name("sum")).bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
......@@ -600,17 +522,15 @@ TEST_CASE(match_any_of_lazy1)
TEST_CASE(match_any_of_lazy2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"),
match::args(match::any(), match::any()).bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
......@@ -619,17 +539,15 @@ TEST_CASE(match_any_of_lazy2)
TEST_CASE(match_any_of_lazy3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"),
match::args(match::name("@literal"), match::name("@literal")).bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
......@@ -638,17 +556,15 @@ TEST_CASE(match_any_of_lazy3)
TEST_CASE(match_any_of_lazy4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::any_of(
match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")),
match::args(match::any().bind("x2"), match::any().bind("y2"))));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1"));
......@@ -660,17 +576,15 @@ TEST_CASE(match_any_of_lazy4)
TEST_CASE(match_any_of_lazy5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::any_of(
match::args(match::any().bind("x1"), match::any().bind("y1")),
match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"))));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1"));
......@@ -682,194 +596,170 @@ TEST_CASE(match_any_of_lazy5)
TEST_CASE(match_none_of1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::none_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_none_of2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_output1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto sum = mm->add_instruction(sum_op{}, minus, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus = mm.add_instruction(minus_op{}, two, one);
auto sum = mm.add_instruction(sum_op{}, minus, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::output(match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_output2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto sum = mm->add_instruction(sum_op{}, minus, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus = mm.add_instruction(minus_op{}, two, one);
auto sum = mm.add_instruction(sum_op{}, minus, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::output(match::name("sum")));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_skip_output1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto sum = mm->add_instruction(sum_op{}, minus, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus = mm.add_instruction(minus_op{}, two, one);
auto sum = mm.add_instruction(sum_op{}, minus, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto minus_pass = mm->add_instruction(pass_op{}, minus);
auto sum = mm->add_instruction(sum_op{}, minus_pass, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus = mm.add_instruction(minus_op{}, two, one);
auto minus_pass = mm.add_instruction(pass_op{}, minus);
auto sum = mm.add_instruction(sum_op{}, minus_pass, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto minus_pass1 = mm->add_instruction(pass_op{}, minus);
auto minus_pass2 = mm->add_instruction(pass_op{}, minus_pass1);
auto minus_pass3 = mm->add_instruction(pass_op{}, minus_pass2);
auto sum = mm->add_instruction(sum_op{}, minus_pass3, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus = mm.add_instruction(minus_op{}, two, one);
auto minus_pass1 = mm.add_instruction(pass_op{}, minus);
auto minus_pass2 = mm.add_instruction(pass_op{}, minus_pass1);
auto minus_pass3 = mm.add_instruction(pass_op{}, minus_pass2);
auto sum = mm.add_instruction(sum_op{}, minus_pass3, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto pass = mm->add_instruction(pass_op{}, one);
auto sum = mm->add_instruction(sum_op{}, pass, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto pass = mm.add_instruction(pass_op{}, one);
auto sum = mm.add_instruction(sum_op{}, pass, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == two});
}
TEST_CASE(match_skip_output5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto pass = mm->add_instruction(pass_op{}, one);
auto sum1 = mm->add_instruction(sum_op{}, pass, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, one);
auto sum3 = mm->add_instruction(sum_op{}, sum2, two);
mm->add_instruction(pass_op{}, sum3);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto pass = mm.add_instruction(pass_op{}, one);
auto sum1 = mm.add_instruction(sum_op{}, pass, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, one);
auto sum3 = mm.add_instruction(sum_op{}, sum2, two);
mm.add_instruction(pass_op{}, sum3);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_skip_output6)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto sum1 = mm->add_instruction(sum_op{}, minus, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, one);
auto sum3 = mm->add_instruction(sum_op{}, sum2, two);
mm->add_instruction(pass_op{}, sum3);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus = mm.add_instruction(minus_op{}, two, one);
auto sum1 = mm.add_instruction(sum_op{}, minus, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, one);
auto sum3 = mm.add_instruction(sum_op{}, sum2, two);
mm.add_instruction(pass_op{}, sum3);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output7)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus1 = mm->add_instruction(minus_op{}, two, one);
auto minus2 = mm->add_instruction(minus_op{}, two, minus1);
auto sum = mm->add_instruction(sum_op{}, one, minus2);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus1 = mm.add_instruction(minus_op{}, two, one);
auto minus2 = mm.add_instruction(minus_op{}, two, minus1);
auto sum = mm.add_instruction(sum_op{}, one, minus2);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == minus1});
}
TEST_CASE(match_bind1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto pass = mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
auto pass = mm.add_instruction(pass_op{}, sum);
auto m = match::name("pass")(
match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
match::name("@literal").bind("two")))
.bind("sum")),
match::standard_shape())
.bind("pass");
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.instructions.at("one") == one});
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
......@@ -877,265 +767,280 @@ TEST_CASE(match_bind1)
EXPECT(bool{r.result == pass});
}
TEST_CASE(match_has_value1)
TEST_CASE(match_bind_modules1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto* child = p.create_module("child");
auto two = child->add_literal(2);
auto sum = child->add_instruction(sum_op{}, one, two);
child->add_instruction(pass_op{}, sum);
mm->add_instruction(mod_pass_op{}, {one}, {child});
auto m = match::name("pass")(
match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
match::name("@literal").bind("two")))
.bind("sum")),
match::standard_shape())
.bind("pass");
auto r = find_match(*child, m);
EXPECT(not migraphx::contains(r.instructions, "one"));
EXPECT(not migraphx::contains(r.instructions, "two"));
EXPECT(not migraphx::contains(r.instructions, "sum"));
EXPECT(not migraphx::contains(r.instructions, "pass"));
EXPECT(bool{r.result == child->end()});
}
TEST_CASE(match_bind_modules2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto* child = p.create_module("child");
auto two = child->add_literal(2);
auto sum = child->add_instruction(sum_op{}, one, two);
auto pass = child->add_instruction(pass_op{}, sum);
mm->add_instruction(mod_pass_op{}, {one}, {child});
auto m = match::name("pass")(
match::args(match::name("sum")(match::args(match::name("@literal"),
match::name("@literal").bind("two")))
.bind("sum")),
match::standard_shape())
.bind("pass");
auto r = find_match(*child, m);
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.result == pass});
}
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
TEST_CASE(match_has_value1)
{
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::has_value(1);
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == one});
}
TEST_CASE(match_has_value2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::has_value(2);
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == two});
}
TEST_CASE(match_has_value3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(2)));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
}
TEST_CASE(match_has_value4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::has_value(3);
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_has_value5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3)));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_has_value6)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1)));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_tree1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::tree(
match::name("sum"), match::has_value(1), match::has_value(2), match::has_value(3));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_tree2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::tree(
match::name("sum"), match::has_value(2), match::has_value(1), match::has_value(3));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_tree3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, three, sum1);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, three, sum1);
mm.add_instruction(pass_op{}, sum2);
auto m = match::tree(
match::name("sum"), match::has_value(3), match::has_value(1), match::has_value(2));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_tree4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::tree(match::name("sum"),
match::has_value(1),
match::has_value(2),
match::has_value(3),
match::has_value(4));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_tree5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::tree(match::name("sum"), match::has_value(2), match::has_value(3));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_tree6)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::tree(match::name("sum"), match::has_value(1), match::has_value(3));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_unordered_tree1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::unordered_tree(
match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_unordered_tree2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, three, sum1);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, three, sum1);
mm.add_instruction(pass_op{}, sum2);
auto m = match::unordered_tree(
match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_unordered_tree3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, two, one);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, two, one);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::unordered_tree(
match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_unordered_tree4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::unordered_tree(
match::name("sum"), match::has_value(4), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
struct match_find_sum
......@@ -1163,14 +1068,12 @@ struct match_find_literal
TEST_CASE(match_finder)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
match::find_matches(*mm, match_find_sum{sum}, match_find_literal{sum});
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
match::find_matches(mm, match_find_sum{sum}, match_find_literal{sum});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -24,33 +24,102 @@ migraphx::program create_program()
return p;
}
TEST_CASE(module_ins_clear)
TEST_CASE(calc_implict_deps)
{
migraphx::program p1 = create_program();
migraphx::program p2;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
p2 = p1;
auto lx = mm->add_literal(migraphx::literal(xs, datax));
auto ly = mm->add_literal(migraphx::literal(ys, datay));
auto cond = mm->add_parameter("cond", cond_s);
auto x1 = mm->add_parameter("x1", xs);
auto x2 = mm->add_parameter("x2", xs);
auto y2 = mm->add_parameter("y2", ys);
EXPECT(p1 == p2);
auto* then_mod = p.create_module("If_5_if");
auto l1 = then_mod->add_literal(migraphx::literal(ys, datay));
auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
then_mod->add_return({a1, l1});
auto* then_mod1 = p.create_module("If_6_if");
auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay));
auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
then_mod1->add_return({a11, l11});
auto* else_mod1 = p.create_module("If_6_else");
auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax));
auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
else_mod1->add_return({l21, a21});
auto* else_mod = p.create_module("If_5_else");
auto l2 = else_mod->add_literal(migraphx::literal(ys, datay));
auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
auto a3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), a2);
else_mod->add_return({a3, l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
auto implicit_deps = mm->calc_implicit_deps();
EXPECT(migraphx::contains(implicit_deps, ret));
EXPECT(migraphx::contains(implicit_deps.at(ret), x1));
EXPECT(migraphx::contains(implicit_deps.at(ret), x2));
EXPECT(migraphx::contains(implicit_deps.at(ret), y2));
}
TEST_CASE(module_print_graph)
TEST_CASE(module_annotate)
{
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
auto* mm1 = p1.get_main_module();
auto* mm2 = p2.get_main_module();
EXPECT(*mm1 == *mm2);
std::stringstream ss1;
mm1->print_graph(ss1, true);
mm1->annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
std::stringstream ss2;
mm2->print_graph(ss2, true);
mm2->annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
EXPECT(ss1.str() == ss2.str());
}
TEST_CASE(module_ins_clear)
{
migraphx::program p1 = create_program();
migraphx::program p2;
p2 = p1;
EXPECT(p1 == p2);
}
TEST_CASE(module_name)
{
migraphx::module m1("name");
EXPECT(m1.name() == "name");
auto m2 = m1; // NOLINT
EXPECT(m2.name() == "name");
migraphx::module m3;
m3 = m1;
EXPECT(m3.name() == "name");
}
TEST_CASE(module_name_main)
{
migraphx::program p;
auto* mm = p.get_main_module();
EXPECT(mm->name() == "main");
}
TEST_CASE(module_print_cpp)
{
migraphx::program p1 = create_program();
......@@ -68,43 +137,23 @@ TEST_CASE(module_print_cpp)
EXPECT(ss1.str() == ss2.str());
}
TEST_CASE(module_annotate)
TEST_CASE(module_print_graph)
{
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
auto* mm1 = p1.get_main_module();
auto* mm2 = p2.get_main_module();
EXPECT(*mm1 == *mm2);
std::stringstream ss1;
mm1->annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
mm1->print_graph(ss1, true);
std::stringstream ss2;
mm2->annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
mm2->print_graph(ss2, true);
EXPECT(ss1.str() == ss2.str());
}
TEST_CASE(module_name)
{
migraphx::module m1("name");
EXPECT(m1.name() == "name");
auto m2 = m1; // NOLINT
EXPECT(m2.name() == "name");
migraphx::module m3;
m3 = m1;
EXPECT(m3.name() == "name");
}
TEST_CASE(module_name_main)
{
migraphx::program p;
auto* mm = p.get_main_module();
EXPECT(mm->name() == "main");
}
TEST_CASE(program_module_assign)
{
migraphx::program p;
......@@ -204,51 +253,4 @@ TEST_CASE(submodule_copy)
EXPECT(mm.get_sub_modules() == mm2.get_sub_modules());
}
TEST_CASE(calc_implict_deps)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
auto lx = mm->add_literal(migraphx::literal(xs, datax));
auto ly = mm->add_literal(migraphx::literal(ys, datay));
auto cond = mm->add_parameter("cond", cond_s);
auto x1 = mm->add_parameter("x1", xs);
auto x2 = mm->add_parameter("x2", xs);
auto y2 = mm->add_parameter("y2", ys);
auto* then_mod = p.create_module("If_5_if");
auto l1 = then_mod->add_literal(migraphx::literal(ys, datay));
auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
then_mod->add_return({a1, l1});
auto* then_mod1 = p.create_module("If_6_if");
auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay));
auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
then_mod1->add_return({a11, l11});
auto* else_mod1 = p.create_module("If_6_else");
auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax));
auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
else_mod1->add_return({l21, a21});
auto* else_mod = p.create_module("If_5_else");
auto l2 = else_mod->add_literal(migraphx::literal(ys, datay));
auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
else_mod->add_return({a2, l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto implicit_deps = mm->calc_implicit_deps();
EXPECT(migraphx::contains(implicit_deps, ret));
EXPECT(migraphx::contains(implicit_deps.at(ret), x1));
EXPECT(migraphx::contains(implicit_deps.at(ret), x2));
EXPECT(migraphx::contains(implicit_deps.at(ret), y2));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -75,6 +75,26 @@ TEST_CASE(gather_test_1)
EXPECT(m1 == m2);
}
migraphx::module create_padded_op(const std::vector<size_t>& pad_vals)
{
migraphx::module m;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
auto si = m.add_parameter("data", s);
auto r = m.add_instruction(migraphx::make_op("pooling", {{"padding", pad_vals}}), si);
m.add_return({r});
return m;
}
TEST_CASE(padding_attr_test)
{
migraphx::module m1 = create_padded_op({0, 1});
migraphx::module m2 = create_padded_op({0, 1, 0, 1});
run_pass(m1);
EXPECT(m1 == m2);
}
migraphx::module create_reduce_mean(const std::vector<int64_t>& axes)
{
migraphx::module m;
......
......@@ -1900,6 +1900,77 @@ def if_then_test():
return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test
def if_tuple_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [1, 4])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [3, 4])
cond_input = onnx.helper.make_tensor_value_info('cond',
onnx.TensorProto.BOOL, [])
then_out0 = onnx.helper.make_tensor_value_info('then_out0',
onnx.TensorProto.FLOAT,
[1, 4])
then_out1 = onnx.helper.make_tensor_value_info('then_out1',
onnx.TensorProto.FLOAT,
[3, 4])
else_out0 = onnx.helper.make_tensor_value_info('else_out0',
onnx.TensorProto.FLOAT,
[1, 4])
else_out1 = onnx.helper.make_tensor_value_info('else_out1',
onnx.TensorProto.FLOAT,
[3, 4])
one = np.ones([1]).astype(np.float)
one_tensor = helper.make_tensor(name='one',
data_type=TensorProto.FLOAT,
dims=one.shape,
vals=one.flatten().astype(np.float32))
two = np.array([2]).astype(np.float)
two_tensor = helper.make_tensor(name='two',
data_type=TensorProto.FLOAT,
dims=two.shape,
vals=two.flatten().astype(np.float32))
three = np.array([3]).astype(np.float)
three_tensor = helper.make_tensor(name='three',
data_type=TensorProto.FLOAT,
dims=three.shape,
vals=three.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'one'],
outputs=['then_out0'])
then_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'two'],
outputs=['then_out1'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['x', 'three'],
outputs=['else_out0'])
else_add_node = onnx.helper.make_node('Add',
inputs=['y', 'three'],
outputs=['else_out1'])
then_body = onnx.helper.make_graph([then_add_node, then_mul_node],
'then_body', [], [then_out0, then_out1])
else_body = onnx.helper.make_graph([else_mul_node, else_add_node],
'else_body', [], [else_out0, else_out1])
res0 = onnx.helper.make_tensor_value_info('res0', TensorProto.FLOAT, [])
res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res0', 'res1'],
then_branch=then_body,
else_branch=else_body)
return ([node], [cond_input, x,
y], [res0, res1], [one_tensor, two_tensor, three_tensor])
@onnx_test
def imagescaler_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16])
......@@ -2218,7 +2289,10 @@ def logsoftmax_nonstd_input_test():
ends=[4, 4],
outputs=['1'])
node1 = onnx.helper.make_node('LogSoftmax', inputs=['1'], outputs=['2'])
node1 = onnx.helper.make_node('LogSoftmax',
inputs=['1'],
outputs=['2'],
axis=-1)
return ([node0, node1], [x], [z])
......@@ -3099,7 +3173,7 @@ def resize_downsample_f_test():
vals=scales.flatten().astype(np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 4])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 1, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node(
'Resize',
......@@ -3133,6 +3207,25 @@ def resize_downsample_c_test():
return ([node], [X], [Y], [scale_tensor])
@onnx_test
def resize_downsample_linear_test():
scales = np.array([1.0, 1.0, 0.6, 0.5], dtype=np.float32)
scale_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 4])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node('Resize',
inputs=['X', '', 'scales'],
outputs=['Y'],
mode='linear')
return ([node], [X], [Y], [scale_tensor])
@onnx_test
def resize_nonstd_input_test():
scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32)
......@@ -3182,6 +3275,46 @@ def resize_outsize_test():
return ([node], [X], [Y], [out_lens_tensor])
@onnx_test
def resize_upsample_linear_ac_test():
scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32)
scales_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(
np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node(
'Resize',
inputs=['X', '', 'scales'],
outputs=['Y'],
mode='linear',
coordinate_transformation_mode='align_corners')
return ([node], [X], [Y], [scales_tensor])
@onnx_test
def resize_upsample_linear_test():
scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32)
scales_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(
np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node('Resize',
inputs=['X', '', 'scales'],
outputs=['Y'],
mode='linear')
return ([node], [X], [Y], [scales_tensor])
@onnx_test
def resize_upsample_pf_test():
scales = np.array([1.0, 1.0, 2.0, 3.0], dtype=np.float32)
......@@ -3429,6 +3562,112 @@ def slice_5arg_test():
return ([arg_step, arg_axis, arg_end, arg_start, node], [x], [y])
@onnx_test
def slice_5arg_reverse_test():
step = np.array([-1, 1])
step_tensor = helper.make_tensor(name="step",
data_type=TensorProto.INT32,
dims=step.shape,
vals=step.astype(int))
arg_step = helper.make_node("Constant",
inputs=[],
outputs=['arg_step'],
value=step_tensor)
axis = np.array([-1, -2])
axis_tensor = helper.make_tensor(name="axis",
data_type=TensorProto.INT32,
dims=axis.shape,
vals=axis.astype(int))
arg_axis = helper.make_node("Constant",
inputs=[],
outputs=['arg_axis'],
value=axis_tensor)
end = np.array([-5, -1])
end_tensor = helper.make_tensor(name="end",
data_type=TensorProto.INT32,
dims=end.shape,
vals=end.astype(int))
arg_end = helper.make_node("Constant",
inputs=[],
outputs=['arg_end'],
value=end_tensor)
start = np.array([-1, -3])
start_tensor = helper.make_tensor(name="start",
data_type=TensorProto.INT32,
dims=start.shape,
vals=start.astype(int))
arg_start = helper.make_node("Constant",
inputs=[],
outputs=['arg_start'],
value=start_tensor)
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 2])
node = onnx.helper.make_node(
'Slice',
inputs=['0', 'arg_start', 'arg_end', 'arg_axis', 'arg_step'],
outputs=['1'])
return ([arg_step, arg_axis, arg_end, arg_start, node], [x], [y])
@onnx_test
def slice_5arg_step_test():
step = np.array([-2, 2])
step_tensor = helper.make_tensor(name="step",
data_type=TensorProto.INT32,
dims=step.shape,
vals=step.astype(int))
arg_step = helper.make_node("Constant",
inputs=[],
outputs=['arg_step'],
value=step_tensor)
axis = np.array([-1, -2])
axis_tensor = helper.make_tensor(name="axis",
data_type=TensorProto.INT32,
dims=axis.shape,
vals=axis.astype(int))
arg_axis = helper.make_node("Constant",
inputs=[],
outputs=['arg_axis'],
value=axis_tensor)
end = np.array([-5, -1])
end_tensor = helper.make_tensor(name="end",
data_type=TensorProto.INT32,
dims=end.shape,
vals=end.astype(int))
arg_end = helper.make_node("Constant",
inputs=[],
outputs=['arg_end'],
value=end_tensor)
start = np.array([-1, -3])
start_tensor = helper.make_tensor(name="start",
data_type=TensorProto.INT32,
dims=start.shape,
vals=start.astype(int))
arg_start = helper.make_node("Constant",
inputs=[],
outputs=['arg_start'],
value=start_tensor)
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 2])
node = onnx.helper.make_node(
'Slice',
inputs=['0', 'arg_start', 'arg_end', 'arg_axis', 'arg_step'],
outputs=['1'])
return ([arg_step, arg_axis, arg_end, arg_start, node], [x], [y])
@onnx_test
def slice_max_end_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [10, 20])
......
......@@ -182,7 +182,8 @@ TEST_CASE(averagepool_1d_test)
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}});
mm->add_instruction(
migraphx::make_op(
"pooling", {{"mode", "average"}, {"padding", {0}}, {"stride", {1}}, {"lengths", {3}}}),
"pooling",
{{"mode", "average"}, {"padding", {0, 0}}, {"stride", {1}}, {"lengths", {3}}}),
l0);
auto prog = optimize_onnx("averagepool_1d_test.onnx");
......@@ -196,7 +197,7 @@ TEST_CASE(averagepool_3d_test)
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}});
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0, 0}},
{"padding", {0, 0, 0, 0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 3, 3}}}),
l0);
......@@ -210,12 +211,13 @@ TEST_CASE(averagepool_notset_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "average"}, {"padding", {2, 2}}, {"stride", {2, 2}}, {"lengths", {6, 6}}}),
input);
auto ret = mm->add_instruction(
auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {2, 2, 2, 2}},
{"stride", {2, 2}},
{"lengths", {6, 6}}}),
input);
auto ret = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {2, 2}}}), ins);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_notset_test.onnx");
......@@ -230,11 +232,12 @@ TEST_CASE(averagepool_nt_cip_test)
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto ret = mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "average"}, {"padding", {0, 0}}, {"stride", {2, 2}}, {"lengths", {6, 6}}}),
ins_pad);
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0, 0, 0}},
{"stride", {2, 2}},
{"lengths", {6, 6}}}),
ins_pad);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_nt_cip_test.onnx");
......@@ -246,12 +249,13 @@ TEST_CASE(averagepool_same_lower_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "average"}, {"padding", {1, 1}}, {"stride", {1, 1}}, {"lengths", {2, 2}}}),
input);
auto ret = mm->add_instruction(
auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {1, 1, 1, 1}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
input);
auto ret = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {0, 0}}, {"ends", {5, 5}}}), ins);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_same_lower_test.onnx");
......@@ -266,11 +270,12 @@ TEST_CASE(averagepool_sl_cip_test)
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 1, 1, 0, 0, 0, 0};
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto ret = mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "average"}, {"padding", {0, 0}}, {"stride", {1, 1}}, {"lengths", {2, 2}}}),
ins_pad);
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
ins_pad);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_sl_cip_test.onnx");
......@@ -282,12 +287,13 @@ TEST_CASE(averagepool_same_upper_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "average"}, {"padding", {1, 1}}, {"stride", {1, 1}}, {"lengths", {2, 2}}}),
input);
auto ret = mm->add_instruction(
auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {1, 1, 1, 1}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
input);
auto ret = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), ins);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_same_upper_test.onnx");
......@@ -606,7 +612,7 @@ TEST_CASE(conv_autopad_same_test)
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}});
migraphx::op::convolution op;
op.padding = {1, 1};
op.padding = {1, 1, 1, 1};
op.padding_mode = migraphx::op::padding_mode_t::same;
mm->add_instruction(op, l0, l1);
......@@ -644,8 +650,9 @@ TEST_CASE(conv_bn_relu_maxpool_test)
auto p5 = mm->add_parameter("5", {migraphx::shape::float_type, {1}});
auto p6 = mm->add_parameter("6", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = mm->add_instruction(migraphx::make_op("convolution"), l0, l1);
auto l4 = mm->add_instruction(
auto l3 =
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1);
auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction(
......@@ -654,7 +661,7 @@ TEST_CASE(conv_bn_relu_maxpool_test)
mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}),
{{"mode", "max"}, {"padding", {0, 0, 0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}),
l7);
auto prog = optimize_onnx("conv_bn_relu_maxpool_test.onnx");
......@@ -669,15 +676,16 @@ TEST_CASE(conv_relu_maxpool_test)
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = mm->add_instruction(migraphx::make_op("convolution"), l0, l1);
auto l4 = mm->add_instruction(
auto l3 =
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1);
auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5);
mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}),
{{"mode", "max"}, {"padding", {0, 0, 0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}),
l6);
auto prog = optimize_onnx("conv_relu_maxpool_test.onnx");
......@@ -692,20 +700,22 @@ TEST_CASE(conv_relu_maxpool_x2_test)
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5, 3, 5, 5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {5}});
uint64_t axis = 1;
auto l3 = mm->add_instruction(migraphx::make_op("convolution"), l0, l1);
auto l4 = mm->add_instruction(
auto l3 =
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1);
auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5);
auto l7 = mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}),
{{"mode", "max"}, {"padding", {0, 0, 0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}),
l6);
auto l8 = mm->add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
auto l9 = mm->add_parameter("4", {migraphx::shape::float_type, {1}});
auto l10 = mm->add_instruction(migraphx::make_op("convolution"), l7, l8);
auto l8 = mm->add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
auto l9 = mm->add_parameter("4", {migraphx::shape::float_type, {1}});
auto l10 =
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l7, l8);
auto l11 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l10->get_shape().lens()}}), l9);
auto l12 = mm->add_instruction(migraphx::make_op("add"), l10, l11);
......@@ -713,7 +723,7 @@ TEST_CASE(conv_relu_maxpool_x2_test)
mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}),
{{"mode", "max"}, {"padding", {0, 0, 0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}),
l13);
auto prog = optimize_onnx("conv_relu_maxpool_x2_test.onnx");
......@@ -825,7 +835,8 @@ TEST_CASE(deconv_input_pads_asymm_1d_test)
auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3}});
auto l2 = mm->add_instruction(
migraphx::make_op("deconvolution", {{"padding", {0}}, {"stride", {2}}, {"dilation", {1}}}),
migraphx::make_op("deconvolution",
{{"padding", {0, 0}}, {"stride", {2}}, {"dilation", {1}}}),
l0,
l1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {6}}}),
......@@ -1117,9 +1128,10 @@ migraphx::program create_external_data_prog()
std::vector<float> weight_data(1210, 1);
std::vector<float> bias_data(10, 1);
auto bias = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {10}}, bias_data));
auto weights = mm->add_literal(migraphx::literal(s2, weight_data));
auto param = mm->add_parameter("input", s);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), param, weights);
auto weights = mm->add_literal(migraphx::literal(s2, weight_data));
auto param = mm->add_parameter("input", s);
auto conv = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), param, weights);
auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 10, 214, 214}}}), bias);
mm->add_instruction(migraphx::make_op("add"), conv, bias_bcast);
......@@ -1300,6 +1312,7 @@ TEST_CASE(globalavgpool_test)
auto op = migraphx::op::pooling{"average"};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
op.padding = {0, 0, 0, 0};
mm->add_instruction(op, input);
auto prog = optimize_onnx("globalavgpool_test.onnx");
......@@ -1316,6 +1329,7 @@ TEST_CASE(globalmaxpool_test)
auto op = migraphx::op::pooling{"max"};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
op.padding = {0, 0, 0, 0};
mm->add_instruction(op, input);
auto prog = optimize_onnx("globalmaxpool_test.onnx");
......@@ -1382,17 +1396,25 @@ TEST_CASE(if_else_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
mm->add_literal(migraphx::literal(sc, {0}));
auto cond = mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
mm->add_literal(s, ones);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-0.583375, 0.633757, 0.0668345, -0.479422, -0.604634, 0.0388589};
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
then_mod->add_return({rt});
mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto* else_mod = p.create_module("If_5_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({re});
auto r = mm->add_instruction(migraphx::make_op("mul"), y, l2);
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
std::ifstream ifs("if_else_test.onnx", std::ios::binary);
......@@ -1404,7 +1426,6 @@ TEST_CASE(if_else_test)
ifs.close();
auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {});
EXPECT(p == prog);
}
......@@ -1430,7 +1451,8 @@ TEST_CASE(if_literal_test)
else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
auto prog = migraphx::parse_onnx("if_literal_test.onnx");
EXPECT(p == prog);
......@@ -1469,7 +1491,8 @@ TEST_CASE(if_param_test)
else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
auto prog = migraphx::parse_onnx("if_param_test.onnx");
EXPECT(p == prog);
......@@ -1502,7 +1525,9 @@ TEST_CASE(if_pl_test)
else_mod->add_return({l2, a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r});
auto prog = migraphx::parse_onnx("if_pl_test.onnx");
EXPECT(p == prog);
......@@ -1513,21 +1538,70 @@ TEST_CASE(if_then_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
mm->add_literal(migraphx::literal(sc, {1}));
auto cond = mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
mm->add_literal(s, rand);
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
then_mod->add_return({rt});
auto r = mm->add_instruction(migraphx::make_op("add"), x, l1);
auto* else_mod = p.create_module("If_5_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({re});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
auto prog = migraphx::parse_onnx("if_then_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_tuple_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
migraphx::shape sc{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", sc);
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r0, r1});
auto prog = migraphx::parse_onnx("if_tuple_test.onnx");
EXPECT(p == prog);
}
......@@ -1995,15 +2069,11 @@ TEST_CASE(maxpool_notset_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
float val = std::numeric_limits<float>::lowest();
auto ins_pad =
mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", val}}), input);
mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {2, 2}}, {"lengths", {6, 6}}}),
ins_pad);
{{"mode", "max"}, {"padding", {0, 0, 1, 1}}, {"stride", {2, 2}}, {"lengths", {6, 6}}}),
input);
auto prog = optimize_onnx("maxpool_notset_test.onnx");
......@@ -2015,15 +2085,11 @@ TEST_CASE(maxpool_same_upper_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
float val = std::numeric_limits<float>::lowest();
auto ins_pad =
mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", val}}), input);
mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {1, 1}}, {"lengths", {2, 2}}}),
ins_pad);
{{"mode", "max"}, {"padding", {0, 0, 1, 1}}, {"stride", {1, 1}}, {"lengths", {2, 2}}}),
input);
auto prog = optimize_onnx("maxpool_same_upper_test.onnx");
......@@ -2664,10 +2730,11 @@ TEST_CASE(reshape_non_standard_test)
EXPECT(p == prog);
}
TEST_CASE(resize_downsample_f_test)
TEST_CASE(resize_downsample_c_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
std::vector<float> ds = {1.0f, 1.0f, 0.6f, 0.6f};
migraphx::shape ss{migraphx::shape::float_type, {4}};
mm->add_literal(migraphx::literal{ss, ds});
......@@ -2678,23 +2745,22 @@ TEST_CASE(resize_downsample_f_test)
mm->add_instruction(migraphx::make_op("undefined"));
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}};
std::vector<int> ind = {4, 7};
std::vector<int> ind = {0, 2};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
auto prog = migraphx::parse_onnx("resize_downsample_f_test.onnx");
auto prog = migraphx::parse_onnx("resize_downsample_c_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(resize_downsample_c_test)
TEST_CASE(resize_downsample_f_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
std::vector<float> ds = {1.0f, 1.0f, 0.6f, 0.6f};
migraphx::shape ss{migraphx::shape::float_type, {4}};
mm->add_literal(migraphx::literal{ss, ds});
......@@ -2705,15 +2771,83 @@ TEST_CASE(resize_downsample_c_test)
mm->add_instruction(migraphx::make_op("undefined"));
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}};
std::vector<int> ind = {0, 2};
std::vector<int> ind = {0, 3};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
auto prog = migraphx::parse_onnx("resize_downsample_c_test.onnx");
auto prog = migraphx::parse_onnx("resize_downsample_f_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(resize_downsample_linear_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ss{migraphx::shape::float_type, {4}};
std::vector<float> ds = {1, 1, 0.6, 0.5};
mm->add_literal(migraphx::literal(ss, ds));
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 4}};
auto x = mm->add_parameter("X", sx);
migraphx::shape s_ind{migraphx::shape::int32_type, {16, 1, 1, 2}};
std::vector<int> d_ind = {0, 2, 0, 2, 0, 2, 0, 2, 4, 6, 4, 6, 4, 6, 4, 6,
1, 3, 1, 3, 1, 3, 1, 3, 5, 7, 5, 7, 5, 7, 5, 7};
auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind));
migraphx::shape s8{migraphx::shape::float_type, {8, 1, 1, 2}};
std::vector<float> d8(16, 0.5f);
auto l8 = mm->add_literal(migraphx::literal(s8, d8));
migraphx::shape s4{migraphx::shape::float_type, {4, 1, 1, 2}};
std::vector<float> d4(8, 1.0f / 3.0f);
auto l4 = mm->add_literal(migraphx::literal(s4, d4));
migraphx::shape s2{migraphx::shape::float_type, {2, 1, 1, 2}};
std::vector<float> d2(4, 0);
auto l2 = mm->add_literal(migraphx::literal(s2, d2));
migraphx::shape s1{migraphx::shape::float_type, {1, 1, 1, 2}};
std::vector<float> d1(2, 0.0f);
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
mm->add_instruction(migraphx::make_op("undefined"));
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), x);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
auto slc81 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), data);
auto diff8 = mm->add_instruction(migraphx::make_op("sub"), slc81, slc80);
auto mul8 = mm->add_instruction(migraphx::make_op("mul"), diff8, l8);
auto add8 = mm->add_instruction(migraphx::make_op("add"), mul8, slc80);
auto slc40 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), add8);
auto slc41 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), add8);
auto diff4 = mm->add_instruction(migraphx::make_op("sub"), slc41, slc40);
auto mul4 = mm->add_instruction(migraphx::make_op("mul"), diff4, l4);
auto add4 = mm->add_instruction(migraphx::make_op("add"), mul4, slc40);
auto slc20 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), add4);
auto slc21 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), add4);
auto diff2 = mm->add_instruction(migraphx::make_op("sub"), slc21, slc20);
auto mul2 = mm->add_instruction(migraphx::make_op("mul"), diff2, l2);
auto add2 = mm->add_instruction(migraphx::make_op("add"), mul2, slc20);
auto slc10 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), add2);
auto slc11 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), add2);
auto diff1 = mm->add_instruction(migraphx::make_op("sub"), slc11, slc10);
auto mul1 = mm->add_instruction(migraphx::make_op("mul"), diff1, l1);
auto add1 = mm->add_instruction(migraphx::make_op("add"), mul1, slc10);
mm->add_return({add1});
auto prog = migraphx::parse_onnx("resize_downsample_linear_test.onnx");
EXPECT(p == prog);
}
......@@ -2773,6 +2907,196 @@ TEST_CASE(resize_nonstd_input_test)
EXPECT(p == prog);
}
TEST_CASE(resize_upsample_linear_ac_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ss{migraphx::shape::float_type, {4}};
std::vector<float> ds = {1, 1, 2, 2};
mm->add_literal(migraphx::literal(ss, ds));
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto x = mm->add_parameter("X", sx);
migraphx::shape s_ind{migraphx::shape::int32_type, {16, 1, 4, 4}};
std::vector<int> d_ind = {
0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2,
2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2,
3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1,
2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0,
1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3,
3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3,
3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3,
2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3};
auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind));
migraphx::shape s8{migraphx::shape::float_type, {8, 1, 4, 4}};
std::vector<float> d8 = {
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0};
auto l8 = mm->add_literal(migraphx::literal(s8, d8));
migraphx::shape s4{migraphx::shape::float_type, {4, 1, 4, 4}};
std::vector<float> d4 = {
0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3,
2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0,
0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3,
2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0,
0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3,
2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0,
0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3,
2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0};
auto l4 = mm->add_literal(migraphx::literal(s4, d4));
migraphx::shape s2{migraphx::shape::float_type, {2, 1, 4, 4}};
std::vector<float> d2(32, 0);
auto l2 = mm->add_literal(migraphx::literal(s2, d2));
migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 4}};
std::vector<float> d1(16, 0.0f);
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
mm->add_instruction(migraphx::make_op("undefined"));
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
auto slc81 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), data);
auto diff8 = mm->add_instruction(migraphx::make_op("sub"), slc81, slc80);
auto mul8 = mm->add_instruction(migraphx::make_op("mul"), diff8, l8);
auto add8 = mm->add_instruction(migraphx::make_op("add"), mul8, slc80);
auto slc40 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), add8);
auto slc41 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), add8);
auto diff4 = mm->add_instruction(migraphx::make_op("sub"), slc41, slc40);
auto mul4 = mm->add_instruction(migraphx::make_op("mul"), diff4, l4);
auto add4 = mm->add_instruction(migraphx::make_op("add"), mul4, slc40);
auto slc20 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), add4);
auto slc21 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), add4);
auto diff2 = mm->add_instruction(migraphx::make_op("sub"), slc21, slc20);
auto mul2 = mm->add_instruction(migraphx::make_op("mul"), diff2, l2);
auto add2 = mm->add_instruction(migraphx::make_op("add"), mul2, slc20);
auto slc10 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), add2);
auto slc11 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), add2);
auto diff1 = mm->add_instruction(migraphx::make_op("sub"), slc11, slc10);
auto mul1 = mm->add_instruction(migraphx::make_op("mul"), diff1, l1);
auto add1 = mm->add_instruction(migraphx::make_op("add"), mul1, slc10);
mm->add_return({add1});
auto prog = migraphx::parse_onnx("resize_upsample_linear_ac_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(resize_upsample_linear_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ss{migraphx::shape::float_type, {4}};
std::vector<float> ds = {1, 1, 2, 2};
mm->add_literal(migraphx::literal(ss, ds));
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto x = mm->add_parameter("X", sx);
migraphx::shape s_ind{migraphx::shape::int32_type, {16, 1, 4, 4}};
std::vector<int> d_ind = {
0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2,
2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2,
3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1,
2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0,
1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3,
3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3,
3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3,
2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3};
auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind));
migraphx::shape s8{migraphx::shape::float_type, {8, 1, 4, 4}};
std::vector<float> d8 = {
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0,
0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0};
auto l8 = mm->add_literal(migraphx::literal(s8, d8));
migraphx::shape s4{migraphx::shape::float_type, {4, 1, 4, 4}};
std::vector<float> d4 = {
0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3,
2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0,
0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3,
2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0,
0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3,
2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0,
0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3,
2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0};
auto l4 = mm->add_literal(migraphx::literal(s4, d4));
migraphx::shape s2{migraphx::shape::float_type, {2, 1, 4, 4}};
std::vector<float> d2(32, 0);
auto l2 = mm->add_literal(migraphx::literal(s2, d2));
migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 4}};
std::vector<float> d1(16, 0.0f);
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
mm->add_instruction(migraphx::make_op("undefined"));
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x);
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
auto slc80 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data);
auto slc81 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), data);
auto diff8 = mm->add_instruction(migraphx::make_op("sub"), slc81, slc80);
auto mul8 = mm->add_instruction(migraphx::make_op("mul"), diff8, l8);
auto add8 = mm->add_instruction(migraphx::make_op("add"), mul8, slc80);
auto slc40 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), add8);
auto slc41 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), add8);
auto diff4 = mm->add_instruction(migraphx::make_op("sub"), slc41, slc40);
auto mul4 = mm->add_instruction(migraphx::make_op("mul"), diff4, l4);
auto add4 = mm->add_instruction(migraphx::make_op("add"), mul4, slc40);
auto slc20 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), add4);
auto slc21 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), add4);
auto diff2 = mm->add_instruction(migraphx::make_op("sub"), slc21, slc20);
auto mul2 = mm->add_instruction(migraphx::make_op("mul"), diff2, l2);
auto add2 = mm->add_instruction(migraphx::make_op("add"), mul2, slc20);
auto slc10 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), add2);
auto slc11 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), add2);
auto diff1 = mm->add_instruction(migraphx::make_op("sub"), slc11, slc10);
auto mul1 = mm->add_instruction(migraphx::make_op("mul"), diff1, l1);
auto add1 = mm->add_instruction(migraphx::make_op("add"), mul1, slc10);
mm->add_return({add1});
auto prog = migraphx::parse_onnx("resize_upsample_linear_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(resize_upsample_pc_test)
{
migraphx::program p;
......@@ -2983,6 +3307,51 @@ TEST_CASE(slice_5arg_test)
EXPECT(p == prog);
}
TEST_CASE(slice_5arg_reverse_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, 1}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-5, -1}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -3}});
auto slice_out = mm->add_instruction(
migraphx::make_op("slice",
{{"axes", {-1, -2}}, {"starts", {-4, -3}}, {"ends", {2147483647, -1}}}),
l0);
auto ret = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {-1}}}), slice_out);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("slice_5arg_reverse_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(slice_5arg_step_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-2, 2}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-5, -1}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -3}});
auto slice_out = mm->add_instruction(
migraphx::make_op("slice",
{{"axes", {-1, -2}}, {"starts", {-4, -3}}, {"ends", {2147483647, -1}}}),
l0);
auto reverse_out =
mm->add_instruction(migraphx::make_op("reverse", {{"axes", {-1}}}), slice_out);
auto step_out = mm->add_instruction(
migraphx::make_op("step", {{"axes", {-1, -2}}, {"steps", {2, 2}}}), reverse_out);
mm->add_return({step_out});
auto prog = migraphx::parse_onnx("slice_5arg_step_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(slice_max_end_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment