Commit 4a39a0f7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

parents 5564172e bb827865
......@@ -20,7 +20,8 @@ struct parse_biasadd : op_parser<parse_biasadd>
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), args[1]);
make_op("broadcast", {{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}),
args[1]);
return info.add_instruction(make_op("add"), args[0], l0);
}
};
......
......@@ -62,16 +62,7 @@ struct parse_conv : op_parser<parse_conv>
calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = info.add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
}
else
{
op.padding[0] = pads[0];
op.padding[1] = pads[1];
}
op.padding = std::vector<size_t>(pads.begin(), pads.end());
}
else if(pad_mode.find("VALID") != std::string::npos)
{
......
......@@ -46,10 +46,12 @@ struct parse_matmul : op_parser<parse_matmul>
// swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
: args[0];
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
auto l1 = (transa)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0])
: args[0];
auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
return info.add_instruction(make_op("dot"), l1, l2);
}
......
......@@ -57,20 +57,7 @@ struct parse_pooling : op_parser<parse_pooling>
calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]);
calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]);
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = info.add_instruction(
migraphx::make_op(
"pad",
{{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
l0);
}
else
{
op.padding[0] = pads[0];
op.padding[1] = pads[1];
}
op.padding = std::vector<size_t>(pads.begin(), pads.end());
}
}
return info.add_instruction(op, l0);
......
......@@ -23,9 +23,9 @@ struct parse_relu6 : op_parser<parse_relu6>
auto max_val = info.add_literal(6.0f);
min_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
max_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
return info.add_instruction(make_op("clip"), args.front(), min_val, max_val);
}
};
......
......@@ -20,7 +20,7 @@ struct parse_transpose : op_parser<parse_transpose>
auto perm = args[1]->eval().get<int32_t>().to_vector();
std::vector<int64_t> dims(perm.begin(), perm.end());
return info.add_instruction(make_op("transpose", {{"dims", dims}}), args.front());
return info.add_instruction(make_op("transpose", {{"permutation", dims}}), args.front());
}
};
......
......@@ -17,6 +17,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tf/tf_parser.hpp>
......@@ -34,20 +35,20 @@ bool tf_parser::should_transpose(instruction_ref ins) const
instruction_ref tf_parser::to_nhwc(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(make_op("transpose", {{"dims", {0, 2, 3, 1}}}), ins);
return mm->add_instruction(make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), ins);
return ins;
}
instruction_ref tf_parser::to_nchw(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(make_op("transpose", {{"dims", {0, 3, 1, 2}}}), ins);
return mm->add_instruction(make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), ins);
return ins;
}
instruction_ref tf_parser::to_kcxy(instruction_ref ins) const
{
return mm->add_instruction(make_op("transpose", {{"dims", {3, 2, 0, 1}}}), ins);
return mm->add_instruction(make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), ins);
}
std::vector<instruction_ref> tf_parser::to_nchw(const std::vector<instruction_ref>& args) const
......@@ -74,66 +75,11 @@ instruction_ref tf_parser::node_info::make_contiguous(instruction_ref ins) const
return mm->add_instruction(make_op("contiguous"), ins);
}
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1)
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
if(s0.size() > s1.size())
{
s0.swap(s1);
}
std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size();
std::transform(
s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
if(a != b and a != 1 and b != 1)
{
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" +
to_string_range(s1) + "} mismatch!");
}
return std::max(a, b);
});
return out_lens;
}
instruction_ref tf_parser::node_info::add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0,
instruction_ref arg1) const
{
if(arg0->get_shape().lens() != arg1->get_shape().lens())
{
// Get lengths for both arguments
auto s0 = arg0->get_shape().lens();
auto s1 = arg1->get_shape().lens();
auto out_lens = compute_broadcasted_lens(s0, s1);
auto l0 = arg0;
if(arg0->get_shape().lens() != out_lens)
l0 = add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), arg0);
auto l1 = arg1;
if(arg1->get_shape().lens() != out_lens)
l1 = add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), arg1);
return add_instruction(make_op(op_name), l0, l1);
}
else
{
return add_instruction(make_op(op_name), {arg0, arg1});
}
return add_common_op(*mm, make_op(op_name), {arg0, arg1});
}
int64_t tf_parser::parse_axis(const int64_t dim, const size_t num_dims) const
......
......@@ -224,23 +224,24 @@ std::vector<value>& get_array_throw(const std::shared_ptr<value_base_impl>& x)
return *a;
}
value* find_impl(const std::shared_ptr<value_base_impl>& x, const std::string& key)
template <class T>
T* find_impl(const std::shared_ptr<value_base_impl>& x, const std::string& key, T* end)
{
auto* a = if_array_impl(x);
if(a == nullptr)
return nullptr;
return end;
auto* lookup = x->if_object();
if(lookup == nullptr)
return nullptr;
return end;
auto it = lookup->find(key);
if(it == lookup->end())
return a->data() + a->size();
return end;
return std::addressof((*a)[it->second]);
}
value* value::find(const std::string& pkey) { return find_impl(x, pkey); }
value* value::find(const std::string& pkey) { return find_impl(x, pkey, this->end()); }
const value* value::find(const std::string& pkey) const { return find_impl(x, pkey); }
const value* value::find(const std::string& pkey) const { return find_impl(x, pkey, this->end()); }
bool value::contains(const std::string& pkey) const
{
const auto* it = find(pkey);
......
......@@ -3,14 +3,14 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
rocm_clang_tidy_check(${NAME})
target_link_libraries(${NAME} migraphx_c)
target_link_libraries(${NAME} migraphx_c migraphx)
target_include_directories(${NAME} PUBLIC ../include)
add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
add_dependencies(tests ${NAME})
add_dependencies(check ${NAME})
endfunction()
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
......
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include <migraphx/compile_options.hpp>
#include "test.hpp"
TEST_CASE(compile_options_api_test)
{
migraphx::api::compile_options options;
options.set_offload_copy(false);
options.set_fast_math(false);
const auto* s_options = reinterpret_cast<const migraphx::MIGRAPHX_INLINE_NS::compile_options*>(
options.get_handle_ptr());
CHECK(s_options->fast_math == false);
CHECK(s_options->offload_copy == false);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -163,4 +163,16 @@ TEST_CASE(get_main_module)
p.print();
}
TEST_CASE(set_loop_default_iter_num)
{
migraphx::onnx_options option;
option.set_default_loop_iterations(15);
auto p = migraphx::parse_onnx("loop_default_test.onnx", option);
auto out_shapes = p.get_output_shapes();
std::vector<std::size_t> out_lens0 = {1};
EXPECT(out_shapes[0].lengths() == out_lens0);
std::vector<std::size_t> out_lens1 = {15, 1};
EXPECT(out_shapes[1].lengths() == out_lens1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -7,8 +7,8 @@ TEST_CASE(load_and_run)
{
auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
auto shapes_before = p.get_output_shapes();
migraphx_compile_options options;
options.offload_copy = true;
migraphx::compile_options options;
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 1);
......@@ -30,8 +30,8 @@ TEST_CASE(if_pl_test)
auto run_prog = [&](auto cond) {
auto p = migraphx::parse_onnx("if_pl_test.onnx");
auto shapes_before = p.get_output_shapes();
migraphx_compile_options options;
options.offload_copy = true;
migraphx::compile_options options;
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 1);
......@@ -74,4 +74,62 @@ TEST_CASE(if_pl_test)
}
}
TEST_CASE(loop_test)
{
auto run_prog = [&](int64_t max_iter_num) {
migraphx::onnx_options parse_options;
parse_options.set_default_loop_iterations(max_iter_num);
auto p = migraphx::parse_onnx("loop_default_test.onnx", parse_options);
auto shapes_before = p.get_output_shapes();
migraphx::compile_options options;
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 2);
CHECK(bool{shapes_before.front() == shapes_after.front()});
migraphx::program_parameters pp;
auto param_shapes = p.get_parameter_shapes();
auto aas = param_shapes["a"];
std::vector<float> xd = {1.0f};
pp.add("a", migraphx::argument(aas, xd.data()));
auto bbs = param_shapes["b"];
std::vector<float> yd = {2.0};
pp.add("b", migraphx::argument(bbs, yd.data()));
auto outputs = p.eval(pp);
auto output = outputs[0];
auto lens = output.get_shape().lengths();
auto elem_num =
std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>());
float* data_ptr = reinterpret_cast<float*>(output.data());
std::vector<std::vector<float>> ret;
ret.push_back({data_ptr, data_ptr + elem_num});
output = outputs[1];
lens = output.get_shape().lengths();
elem_num = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>());
data_ptr = reinterpret_cast<float*>(output.data());
ret.push_back({data_ptr, data_ptr + elem_num});
return ret;
};
{
auto result_vector = run_prog(10);
std::vector<float> gold0 = {2.0f};
EXPECT(result_vector.at(0) == gold0);
std::vector<float> gold1 = {-2, 4, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(result_vector.at(1) == gold1);
}
{
auto result_vector = run_prog(15);
std::vector<float> gold0 = {2.0f};
EXPECT(result_vector.at(0) == gold0);
std::vector<float> gold1 = {-2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(result_vector.at(1) == gold1);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -22,8 +22,8 @@ TEST_CASE(load_save_json)
std::string filename = "migraphx_api_load_save.json";
auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
auto s1 = p1.get_output_shapes();
migraphx_file_options options;
options.format = "json";
migraphx::file_options options;
options.set_file_format("json");
migraphx::save(p1, filename.c_str(), options);
auto p2 = migraphx::load(filename.c_str(), options);
......
......@@ -98,6 +98,22 @@ TEST_CASE(nested_tuple)
EXPECT(a1.to_string() != a3.to_string());
}
TEST_CASE(tuple_construct)
{
migraphx::shape s{{migraphx::shape{migraphx::shape::float_type, {4}},
migraphx::shape{migraphx::shape::int8_type, {3}}}};
migraphx::argument a{s};
EXPECT(a.get_sub_objects().size() == 2);
EXPECT(a.get_shape() == s);
auto b = a; // NOLINT
EXPECT(a.get_shape() == b.get_shape());
EXPECT(a.get_sub_objects().size() == 2);
EXPECT(a.get_sub_objects()[0] == b.get_sub_objects()[0]);
EXPECT(a.get_sub_objects()[1] == b.get_sub_objects()[1]);
EXPECT(a == b);
}
TEST_CASE(tuple_visit)
{
auto a1 = make_tuple(3, 3.0);
......
......@@ -40,7 +40,7 @@ TEST_CASE(after_literal_transpose)
auto l = m.add_literal(get_2x2());
EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().transposed());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
m.add_instruction(pass_op{}, t);
EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().transposed());
......@@ -58,7 +58,7 @@ TEST_CASE(after_literal_broadcast)
EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().broadcasted());
auto b = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2);
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", l1->get_shape().lens()}}), l2);
m.add_instruction(pass_op{}, b);
EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().broadcasted());
......@@ -74,7 +74,7 @@ TEST_CASE(after_param_transpose)
auto l = m.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().transposed());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
m.add_instruction(pass_op{}, t);
EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().transposed());
......@@ -92,7 +92,7 @@ TEST_CASE(after_param_broadcast)
EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().broadcasted());
auto b = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2);
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", l1->get_shape().lens()}}), l2);
m.add_instruction(pass_op{}, b);
EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().broadcasted());
......
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
......@@ -113,15 +114,16 @@ TEST_CASE(depth_test)
TEST_CASE(undefined_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto undef = mm->add_instruction(migraphx::make_op("undefined"));
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
EXPECT(not mm->has_instruction(undef));
EXPECT(
std::none_of(mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "undefined"; }));
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
......@@ -195,4 +197,24 @@ TEST_CASE(unused_module)
EXPECT(not migraphx::contains(p.get_modules(), m1));
}
TEST_CASE(param_not_eliminated)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {2, 2}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_parameter("z", s);
auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({sum});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_program());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/decompose.hpp>
#include <migraphx/pass_manager.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::decompose{}}); }
TEST_CASE(dot_add)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = m1.add_instruction(migraphx::make_op("dot"), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, z);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_float)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_half)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_double)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_int)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0.5}});
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/dom_info.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
TEST_CASE(dom1)
{
migraphx::module mm;
auto ins1 = mm.add_parameter("entry", {migraphx::shape::float_type});
auto ins2 = mm.add_instruction(pass_op{}, ins1);
auto ins3 = mm.add_instruction(pass_op{}, ins2);
auto ins4 = mm.add_instruction(pass_op{}, ins2);
auto ins5 = mm.add_instruction(pass_op{}, ins3, ins4);
auto ins6 = mm.add_instruction(pass_op{}, ins2);
auto dom = migraphx::compute_dominator(mm);
EXPECT(dom.strictly_dominate(ins1, ins2));
EXPECT(dom.strictly_dominate(ins2, ins3));
EXPECT(dom.strictly_dominate(ins2, ins4));
EXPECT(dom.strictly_dominate(ins2, ins5));
EXPECT(dom.strictly_dominate(ins2, ins6));
EXPECT(not dom.strictly_dominate(ins3, ins6));
EXPECT(not dom.strictly_dominate(ins4, ins6));
EXPECT(not dom.strictly_dominate(ins3, ins5));
EXPECT(not dom.strictly_dominate(ins4, ins5));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <cstdint>
#include <migraphx/instruction.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
TEST_CASE(dot_apply_alpha_beta_half)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot_res = migraphx::insert_apply_alpha_beta(
m1, m1.end(), {x, y, z}, migraphx::make_op("dot"), 3.0f, 2.0f);
m1.add_instruction(migraphx::make_op("identity"), dot_res);
}
migraphx::module m2;
{
auto ht = migraphx::shape::half_type;
auto ft = migraphx::shape::float_type;
auto x = m2.add_parameter("x", migraphx::shape{ht, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{ht, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{ht, {2, 2}});
auto alpha_literal = m2.add_literal(3.0f);
auto alpha_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}),
alpha_literal);
auto x_float = m2.add_instruction(migraphx::make_op("convert", {{"target_type", ft}}), x);
auto x_alpha_float = m2.add_instruction(migraphx::make_op("mul"), alpha_broadcast, x_float);
auto x_half =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", ht}}), x_alpha_float);
auto dot_res = m2.add_instruction(migraphx::make_op("dot"), x_half, y);
auto beta_literal = m2.add_literal(2.0f);
auto z_float = m2.add_instruction(migraphx::make_op("convert", {{"target_type", ft}}), z);
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", z->get_shape().lens()}}),
beta_literal);
auto z_beta_float = m2.add_instruction(migraphx::make_op("mul"), z_float, beta_broadcast);
auto z_beta_half =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", ht}}), z_beta_float);
auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_half);
m2.add_instruction(migraphx::make_op("identity"), z_add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_apply_alpha_beta_double)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 1}});
auto dot_res =
migraphx::add_apply_alpha_beta(m1, {x, y, z}, migraphx::make_op("dot"), 3.0f, 2.0f);
m1.add_instruction(migraphx::make_op("identity"), dot_res);
}
migraphx::module m2;
{
auto dt = migraphx::shape::double_type;
auto x = m2.add_parameter("x", migraphx::shape{dt, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{dt, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{dt, {2, 1}});
auto alpha_literal = m2.add_literal(3.0f);
auto alpha_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}),
alpha_literal);
auto alpha_double = m2.add_instruction(migraphx::make_op("convert", {{"target_type", dt}}),
alpha_broadcast);
auto x_alpha_double = m2.add_instruction(migraphx::make_op("mul"), alpha_double, x);
auto dot_res = m2.add_instruction(migraphx::make_op("dot"), x_alpha_double, y);
auto z_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), z);
auto beta_literal = m2.add_literal(2.0f);
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", z_broadcast->get_shape().lens()}}),
beta_literal);
auto beta_double =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", dt}}), beta_broadcast);
auto z_beta_double = m2.add_instruction(migraphx::make_op("mul"), z_broadcast, beta_double);
auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_double);
m2.add_instruction(migraphx::make_op("identity"), z_add);
}
EXPECT(m1 == m2);
}
TEST_CASE(quant_dot_apply_alpha_beta)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot_res = migraphx::insert_apply_alpha_beta(m1,
m1.end(),
{x, y, z},
migraphx::make_op("quant_dot"),
migraphx::literal{int32_t{3}},
migraphx::literal{int32_t{2}});
m1.add_instruction(migraphx::make_op("identity"), dot_res);
}
migraphx::module m2;
{
auto i8 = migraphx::shape::int8_type;
auto i32 = migraphx::shape::int32_type;
auto x = m2.add_parameter("x", migraphx::shape{i8, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{i8, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{i32, {2, 2}});
auto alpha_literal = m2.add_literal(int32_t(3));
auto alpha_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}),
alpha_literal);
auto x_i32 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", i32}}), x);
auto x_alpha_i32 = m2.add_instruction(migraphx::make_op("mul"), alpha_broadcast, x_i32);
auto x_i8 =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", i8}}), x_alpha_i32);
auto dot_res = m2.add_instruction(migraphx::make_op("quant_dot"), x_i8, y);
auto beta_literal = m2.add_literal(int32_t(2));
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", z->get_shape().lens()}}),
beta_literal);
auto z_beta_i32 = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_i32);
m2.add_instruction(migraphx::make_op("identity"), z_add);
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -17,7 +17,7 @@ TEST_CASE(standard_op)
migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_standard_op{}, c);
auto count = std::distance(m.begin(), m.end());
......@@ -30,7 +30,7 @@ TEST_CASE(standard_op_const)
migraphx::module m;
auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_standard_op{}, c);
run_pass(m);
......@@ -42,7 +42,7 @@ TEST_CASE(non_standard_op)
migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_op{}, c);
auto count = std::distance(m.begin(), m.end());
......@@ -55,7 +55,7 @@ TEST_CASE(non_standard_op_const)
migraphx::module m;
auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_op{}, c);
run_pass(m);
......@@ -67,7 +67,7 @@ TEST_CASE(transpose_gem)
migraphx::module m;
auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto ic = m.add_instruction(migraphx::make_op("identity"), c);
m.add_instruction(migraphx::make_op("dot"), ic, l);
......@@ -81,7 +81,7 @@ TEST_CASE(transpose_standard_op)
migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto sn = m.add_instruction(migraphx::make_op("sin"), c);
m.add_instruction(pass_standard_op{}, sn);
......@@ -95,7 +95,7 @@ TEST_CASE(transpose_standard_op_const)
migraphx::module m;
auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto sn = m.add_instruction(migraphx::make_op("sin"), c);
m.add_instruction(pass_standard_op{}, sn);
......@@ -123,7 +123,7 @@ TEST_CASE(non_standard_return_input)
migraphx::module m;
auto l = m.add_literal(get_2x2());
auto tl = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto tl = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), tl);
m.add_return({c});
auto count = std::distance(m.begin(), m.end());
......@@ -131,4 +131,32 @@ TEST_CASE(non_standard_return_input)
EXPECT(std::distance(m.begin(), m.end()) == count);
}
TEST_CASE(non_standard_flatten_op)
{
migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 6, 6, 6}});
auto t = m.add_instruction(
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(migraphx::make_op("flatten"), c);
auto count = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(std::distance(m.begin(), m.end()) == count);
}
TEST_CASE(standard_flatten_op)
{
migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 6, 6, 6}});
auto t = m.add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(migraphx::make_op("flatten"), c);
auto count = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(std::distance(m.begin(), m.end()) == (count - 1));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment