Commit db1a954c authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into fuse-dot-weights

parents f92195d0 333860ce
......@@ -100,7 +100,7 @@ struct parse_conv : op_parser<parse_conv>
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
if(padding[0] != padding[2] or padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
......
......@@ -90,7 +90,7 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);
if(pads[0] != pads[2] || pads[1] != pads[3])
if(pads[0] != pads[2] or pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = info.add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
......
......@@ -42,7 +42,7 @@ struct parse_pooling : op_parser<parse_pooling>
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
if(!starts_with(opd.tf_name, "Max") && !starts_with(opd.tf_name, "Av"))
if(not starts_with(opd.tf_name, "Max") and not starts_with(opd.tf_name, "Av"))
{
MIGRAPHX_THROW("tf pooling mode must be Max or Average");
}
......
......@@ -41,8 +41,9 @@ struct parse_relu6 : op_parser<parse_relu6>
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto min_val = info.add_literal(0.0f);
auto max_val = info.add_literal(6.0f);
shape::type_t output_type = args[0]->get_shape().type();
auto min_val = info.add_literal(migraphx::literal{migraphx::shape{output_type}, {0.0f}});
auto max_val = info.add_literal(migraphx::literal{migraphx::shape{output_type}, {6.0f}});
return info.add_common_op("clip", args[0], min_val, max_val);
}
......
......@@ -371,7 +371,7 @@ void tf_parser::parse_node(const std::string& name)
{
result = ops[node.op()](*this, {get_attributes(node), node.op(), mm}, args);
}
assert(!result.empty());
assert(not result.empty());
// First output has no ":" delimiter
instructions[name] = result.front();
for(size_t i = 1; i < result.size(); i++)
......@@ -458,7 +458,7 @@ literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const
{
std::vector<size_t> dims = parse_dims(t.tensor_shape());
size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
if(!t.tensor_content().empty()) // has raw data
if(not t.tensor_content().empty()) // has raw data
{
const std::string& s = t.tensor_content();
switch(t.dtype())
......
......@@ -78,7 +78,7 @@ void tmp_dir::execute(const std::string& exe, const std::string& args) const
tmp_dir::~tmp_dir()
{
if(!enabled(MIGRAPHX_DEBUG_SAVE_TEMP_DIR{}))
if(not enabled(MIGRAPHX_DEBUG_SAVE_TEMP_DIR{}))
{
fs::remove_all(this->path);
}
......
......@@ -400,7 +400,7 @@ std::pair<value*, bool> value::insert(const value& v)
{
if(v.key.empty())
{
if(!x)
if(not x)
x = std::make_shared<array_value_holder>();
get_array_impl(x).push_back(v);
assert(this->if_array());
......@@ -408,7 +408,7 @@ std::pair<value*, bool> value::insert(const value& v)
}
else
{
if(!x)
if(not x)
x = std::make_shared<object_value_holder>();
auto p = x->if_object()->emplace(v.key, get_array_impl(x).size());
if(p.second)
......@@ -420,7 +420,7 @@ std::pair<value*, bool> value::insert(const value& v)
value* value::insert(const value* pos, const value& v)
{
assert(v.key.empty());
if(!x)
if(not x)
x = std::make_shared<array_value_holder>();
auto&& a = get_array_impl(x);
auto it = a.insert(a.begin() + (pos - begin()), v);
......@@ -466,7 +466,7 @@ bool compare(const value& x, const value& y, F f)
value::type_t value::get_type() const
{
if(!x)
if(not x)
return null_type;
return x->get_type();
}
......
......@@ -55,7 +55,7 @@ struct simple_custom_op final : migraphx::experimental_custom_op_base
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{
if(!inputs[0].standard())
if(not inputs[0].standard())
{
throw std::runtime_error("first arg must be standard shaped");
}
......
......@@ -49,6 +49,6 @@ bool create_shapes(bool dynamic_allowed)
TEST_CASE(allow_dynamic_shape) { EXPECT(create_shapes(true)); }
TEST_CASE(fail_dynamic_shape) { EXPECT(!create_shapes(false)); }
TEST_CASE(fail_dynamic_shape) { EXPECT(not create_shapes(false)); }
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -187,7 +187,7 @@ TEST_CASE(print_test)
std::stringstream ss;
ss << p;
std::string s = ss.str();
EXPECT(!s.empty());
EXPECT(not s.empty());
}
TEST_CASE(param_test)
......
......@@ -47,7 +47,7 @@ TEST_CASE(is_supported)
{
auto p = create_program();
auto targets = migraphx::get_targets();
EXPECT(!targets.empty());
EXPECT(not targets.empty());
auto t = migraphx::make_target("fpga");
const auto assignments = p.get_target_assignments({t});
......
......@@ -112,12 +112,12 @@ struct mod_pass_op
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs,
std::vector<migraphx::module_ref> mods) const
{
if(!mods.empty())
if(not mods.empty())
{
auto out_shapes = mods[0]->get_output_shapes();
return out_shapes[0];
}
if(!inputs.empty())
if(not inputs.empty())
{
return inputs.front();
}
......@@ -186,9 +186,10 @@ struct nop
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
};
inline migraphx::literal get_2x2()
inline migraphx::literal get_2x2(int base = 0)
{
return migraphx::literal{{migraphx::shape::float_type, {2, 2}}, {1, 2, 3, 4}};
return migraphx::literal{{migraphx::shape::float_type, {2, 2}},
{base + 1, base + 2, base + 3, base + 4}};
}
inline migraphx::literal get_2x2_transposed()
......
......@@ -345,7 +345,7 @@ inline std::ostream& operator<<(std::ostream& os, const color& c)
template <class T, class F>
void failed(T x, const char* msg, const char* func, const char* file, int line, F f)
{
if(!bool(x.value()))
if(not bool(x.value()))
{
std::cout << func << std::endl;
std::cout << file << ":" << line << ":" << std::endl;
......
......@@ -39,8 +39,8 @@ TEST_CASE(literal_test)
migraphx::literal l2 = l1; // NOLINT
EXPECT(l1 == l2);
EXPECT(l1.at<int>(0) == 1);
EXPECT(!l1.empty());
EXPECT(!l2.empty());
EXPECT(not l1.empty());
EXPECT(not l2.empty());
migraphx::literal l3{};
migraphx::literal l4{};
......
......@@ -38,7 +38,6 @@
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/reshape.hpp>
......
......@@ -3988,7 +3988,8 @@ TEST_CASE(not_test)
std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold(data.size());
std::transform(data.begin(), data.end(), gold.begin(), [](bool n) -> bool { return !n; });
std::transform(
data.begin(), data.end(), gold.begin(), [](bool n) -> bool { return not n; });
EXPECT(migraphx::verify_range(results_vector, gold));
}
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/verify.hpp>
TEST_CASE(bias_gelu)
{
migraphx::shape s1{migraphx::shape::half_type, {2, 4, 8}};
migraphx::shape s2{migraphx::shape::half_type};
migraphx::module m1;
{
auto a = m1.add_parameter("a", s1);
auto b = m1.add_parameter("b", s1);
auto add1 = m1.add_instruction(migraphx::make_op("add"), a, b);
auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}});
auto div = add_common_op(m1, migraphx::make_op("div"), {add1, l1});
auto erf = m1.add_instruction(migraphx::make_op("erf"), div);
auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}});
auto add2 = add_common_op(m1, migraphx::make_op("add"), {erf, l2});
auto mul = m1.add_instruction(migraphx::make_op("mul"), add1, add2);
auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}});
mul = add_common_op(m1, migraphx::make_op("mul"), {mul, l3});
m1.add_return({mul});
}
migraphx::rewrite_gelu pass;
pass.apply(m1);
migraphx::dead_code_elimination dce;
dce.apply(m1);
migraphx::module m2;
{
auto a = m2.add_parameter("a", s1);
auto b = m2.add_parameter("b", s1);
auto add = m2.add_instruction(migraphx::make_op("add"), a, b);
auto l1 = m2.add_literal(migraphx::literal{s2, {1.702f}});
auto mul = add_common_op(m2, migraphx::make_op("mul"), {add, l1});
auto sig = m2.add_instruction(migraphx::make_op("neg"), mul);
sig = m2.add_instruction(migraphx::make_op("exp"), sig);
auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}});
sig = add_common_op(m2, migraphx::make_op("add"), {sig, l2});
sig = m2.add_instruction(migraphx::make_op("div"), add, sig);
m2.add_return({sig});
}
EXPECT(m1 == m2);
}
TEST_CASE(non_bias_gelu)
{
migraphx::shape s1{migraphx::shape::half_type, {2, 4, 8}};
migraphx::shape s2{migraphx::shape::half_type};
migraphx::module m1;
{
auto a = m1.add_parameter("a", s1);
auto b = m1.add_parameter("b", s1);
auto sub = m1.add_instruction(migraphx::make_op("sub"), a, b);
auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}});
auto div = add_common_op(m1, migraphx::make_op("div"), {sub, l1});
auto erf = m1.add_instruction(migraphx::make_op("erf"), div);
auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}});
auto add2 = add_common_op(m1, migraphx::make_op("add"), {erf, l2});
auto mul = m1.add_instruction(migraphx::make_op("mul"), sub, add2);
auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}});
mul = add_common_op(m1, migraphx::make_op("mul"), {mul, l3});
m1.add_return({mul});
}
migraphx::rewrite_gelu pass;
pass.apply(m1);
migraphx::dead_code_elimination dce;
dce.apply(m1);
migraphx::module m2;
{
auto a = m2.add_parameter("a", s1);
auto b = m2.add_parameter("b", s1);
auto sub = m2.add_instruction(migraphx::make_op("sub"), a, b);
auto l1 = m2.add_literal(migraphx::literal{s2, {1.702f}});
auto mul = add_common_op(m2, migraphx::make_op("mul"), {sub, l1});
auto sig = m2.add_instruction(migraphx::make_op("neg"), mul);
sig = m2.add_instruction(migraphx::make_op("exp"), sig);
auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}});
sig = add_common_op(m2, migraphx::make_op("add"), {sig, l2});
sig = m2.add_instruction(migraphx::make_op("div"), sub, sig);
m2.add_return({sig});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -43,7 +43,7 @@ TEST_CASE(test_shape_assign)
migraphx::shape s1{migraphx::shape::float_type, {100, 32, 8, 8}};
migraphx::shape s2 = s1; // NOLINT
EXPECT(s1 == s2);
EXPECT(!(s1 != s2));
EXPECT(not(s1 != s2));
}
TEST_CASE(test_shape_packed_default)
......@@ -325,7 +325,7 @@ TEST_CASE(test_shape_default_copy)
migraphx::shape s1{};
migraphx::shape s2{};
EXPECT(s1 == s2);
EXPECT(!(s1 != s2));
EXPECT(not(s1 != s2));
}
TEST_CASE(test_shape_normalize_standard1)
......
......@@ -30,7 +30,6 @@
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m)
......@@ -358,7 +357,33 @@ TEST_CASE(simplify_mul_add)
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast)
TEST_CASE(simplify_dot_add)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto one = m1.add_literal(get_2x2());
auto two = m1.add_literal(get_2x2(1));
auto sum = m1.add_instruction(migraphx::make_op("add"), one, x);
auto dot = m1.add_instruction(migraphx::make_op("dot"), sum, two);
m1.add_instruction(pass_op{}, dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto one = m2.add_literal(get_2x2());
auto two = m2.add_literal(get_2x2(1));
auto dot1 = m2.add_instruction(migraphx::make_op("dot"), x, two);
auto dot2 = m2.add_instruction(migraphx::make_op("dot"), one, two);
auto sum = m2.add_instruction(migraphx::make_op("add"), dot1, dot2);
m2.add_instruction(pass_op{}, sum);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast1)
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
migraphx::module m1;
......@@ -383,6 +408,31 @@ TEST_CASE(simplify_inner_broadcast)
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast2)
{
auto b = migraphx::op::multibroadcast{{2, 1, 4, 5}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 1}});
auto xb = m1.add_instruction(b, x);
auto yb = m1.add_instruction(b, y);
auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
m1.add_instruction(pass_op{}, sum);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 1}});
auto sum = m2.add_instruction(migraphx::make_op("add"), x, y);
auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_add_conv1)
{
migraphx::module m;
......@@ -1477,6 +1527,48 @@ TEST_CASE(simplify_dot_horiz_flipped)
EXPECT(m1.sort() == m2.sort());
}
// test if contiguous is added as necessary for reshapes
TEST_CASE(simplify_dot_horiz_reshape)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 4, 4}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s);
auto a = m1.add_literal(migraphx::generate_literal(s, 0));
auto b = m1.add_literal(migraphx::generate_literal(s, 1));
auto x = m1.add_instruction(migraphx::make_op("dot"), input, a);
auto y = m1.add_instruction(migraphx::make_op("dot"), input, b);
auto x_rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2}}}), x);
auto y_rsp =
m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), y);
auto sum = m1.add_instruction(migraphx::make_op("add"), {x_rsp, y_rsp});
m1.add_instruction(pass_op{}, sum);
}
run_pass(m1);
migraphx::module m2;
{
auto input = m2.add_parameter("input", s);
auto a = m2.add_literal(migraphx::generate_literal(s, 0));
auto b = m2.add_literal(migraphx::generate_literal(s, 1));
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b);
auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat);
auto x = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {4}}}), dot);
auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {4}}, {"ends", {8}}}), dot);
auto x_cont = m2.add_instruction(migraphx::make_op("contiguous"), x);
auto x_rsp =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2}}}), x_cont);
auto y_rsp =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), {x_rsp, y_rsp});
m2.add_instruction(pass_op{}, sum);
}
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_conv_horiz)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {8, 3, 64, 64}};
......@@ -1782,13 +1874,19 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
}
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice)
template <std::size_t BS, bool TransposeInput>
void reorder_reshape_slice()
{
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto create_m1 = [&](std::size_t batch_size) {
migraphx::module m1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
migraphx::module m1;
{
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
if(TransposeInput)
{
s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}, {165120, 1, 128}};
}
auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
......@@ -1803,7 +1901,7 @@ TEST_CASE(reorder_reshape_slice)
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 10, 64};
std::vector<int64_t> lens = {static_cast<int64_t>(BS), 128, 10, 64};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
......@@ -1815,16 +1913,23 @@ TEST_CASE(reorder_reshape_slice)
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
m1.add_return({ret});
return m1;
};
auto create_m2 = [&](std::size_t batch_size) {
migraphx::module m2;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = m2.add_parameter("input", s);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 30, 64};
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
migraphx::module m2;
{
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
if(TransposeInput)
{
s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}, {165120, 1, 128}};
}
auto input = m2.add_parameter("input", s);
auto rsp_input = input;
if(TransposeInput)
{
rsp_input = m2.add_instruction(migraphx::make_op("contiguous"), {input});
}
std::vector<int64_t> lens = {static_cast<int64_t>(BS), 128, 30, 64};
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), rsp_input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {10}}}), r);
......@@ -1843,27 +1948,25 @@ TEST_CASE(reorder_reshape_slice)
auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2);
m2.add_return({ret});
return m2;
};
auto test = [&](std::size_t batch_size) {
auto m1 = create_m1(batch_size);
run_pass(m1);
auto m2 = create_m2(batch_size);
EXPECT(m1.sort() == m2.sort());
};
test(1);
test(4);
test(8);
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice_move_axis1)
TEST_CASE_REGISTER(reorder_reshape_slice<1, true>); // test if contiguous is added as necessary if
// input is transposed
TEST_CASE_REGISTER(reorder_reshape_slice<4, true>);
TEST_CASE_REGISTER(reorder_reshape_slice<8, true>);
TEST_CASE_REGISTER(reorder_reshape_slice<1, false>);
TEST_CASE_REGISTER(reorder_reshape_slice<4, false>);
TEST_CASE_REGISTER(reorder_reshape_slice<8, false>);
template <std::size_t BS>
void reorder_reshape_slice_move_axis1()
{
auto create_m1 = [](std::size_t batch_size) {
migraphx::module m1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
migraphx::module m1;
{
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 256, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = m1.add_parameter("input", s);
......@@ -1878,7 +1981,7 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 32};
std::vector<int64_t> lens = {static_cast<int64_t>(BS), 64, 4, 32};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
......@@ -1890,50 +1993,45 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
m1.add_return({ret});
return m1;
};
auto create_m2 = [](std::size_t batch_size) {
migraphx::module m;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
migraphx::module m2;
{
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 256, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = m.add_parameter("input", s);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 96};
auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m.add_instruction(
auto input = m2.add_parameter("input", s);
std::vector<int64_t> lens = {static_cast<int64_t>(BS), 64, 4, 96};
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
auto t0 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
auto slc1 = m.add_instruction(
auto t0 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
auto slc2 = m.add_instruction(
auto t1 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
auto sum = m.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m.add_instruction(migraphx::make_op("dot"), sum, t2);
m.add_return({ret});
return m;
};
auto t2 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
auto test = [&](std::size_t batch_size) {
auto m1 = create_m1(batch_size);
auto m2 = create_m2(batch_size);
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2);
m2.add_return({ret});
};
test(4);
test(8);
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE_REGISTER(reorder_reshape_slice_move_axis1<4>);
TEST_CASE_REGISTER(reorder_reshape_slice_move_axis1<8>);
TEST_CASE(reorder_reshape_slice_move_axis2)
{
auto create_m1 = [] {
migraphx::module m1;
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction(
......@@ -1955,32 +2053,75 @@ TEST_CASE(reorder_reshape_slice_move_axis2)
auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2);
m1.add_return({ret});
return m1;
};
auto create_m2 = [] {
migraphx::module m;
migraphx::module m2;
{
auto s = migraphx::shape{migraphx::shape::float_type, {128, 96}};
auto input = m.add_parameter("input", s);
auto input = m2.add_parameter("input", s);
std::vector<int64_t> lens = {1, 16, 8, 96};
auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m.add_instruction(
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
auto slc1 = m.add_instruction(
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
auto slc2 = m.add_instruction(
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);
auto sum = m.add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = m.add_instruction(migraphx::make_op("mul"), sum, slc2);
m.add_return({ret});
auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
m2.add_return({ret});
};
return m;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice_len_1)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {1, 128, 3}};
auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {1}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {2}}}), input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {3}}}), input);
auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {1, 128};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2);
m1.add_return({ret});
};
migraphx::module m2;
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 128, 3}};
auto input = m2.add_parameter("input", s);
std::vector<int64_t> lens = {1, 384};
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {128}}}), rsp);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {128}}, {"ends", {256}}}), rsp);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {256}}, {"ends", {384}}}), rsp);
auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
m2.add_return({ret});
};
auto m1 = create_m1();
auto m2 = create_m2();
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
......@@ -2020,15 +2161,14 @@ TEST_CASE(reorder_reshape_slice_not_apply)
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice_diff_dims)
template <std::size_t BS>
void reorder_reshape_slice_diff_dims()
{
auto create_m1 = [](std::size_t batch_size) {
migraphx::module m1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 96, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction(
migraphx::module m1;
{
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 96, 96}};
auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), input);
......@@ -2039,34 +2179,31 @@ TEST_CASE(reorder_reshape_slice_diff_dims)
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 32, 3, 32};
std::vector<int64_t> lens1 = {static_cast<int64_t>(batch_size), 48, 2, 32};
std::vector<int64_t> lens = {static_cast<int64_t>(BS), 32, 3, 32};
std::vector<int64_t> lens1 = {static_cast<int64_t>(BS), 48, 2, 32};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c2);
m1.add_return({r0, r1, r2});
return m1;
};
auto test = [&](std::size_t batch_size) {
auto m1 = create_m1(batch_size);
auto m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
};
test(4);
test(8);
auto m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_slice_trans)
TEST_CASE_REGISTER(reorder_reshape_slice_diff_dims<4>);
TEST_CASE_REGISTER(reorder_reshape_slice_diff_dims<8>);
template <std::size_t BS>
void reorder_slice_trans()
{
std::vector<int64_t> perm = {0, 2, 1};
auto create_m1 = [&](std::size_t batch_size) {
migraphx::module m1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
migraphx::module m1;
{
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
......@@ -2084,13 +2221,11 @@ TEST_CASE(reorder_slice_trans)
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, t2);
m1.add_return({ret});
return m1;
};
auto create_m2 = [&](std::size_t batch_size) {
migraphx::module m2;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
migraphx::module m2;
{
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
auto input = m2.add_parameter("input", s);
auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
......@@ -2104,26 +2239,21 @@ TEST_CASE(reorder_slice_trans)
auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
m2.add_return({ret});
return m2;
};
auto test = [&](std::size_t batch_size) {
auto m1 = create_m1(batch_size);
run_pass(m1);
auto m2 = create_m2(batch_size);
EXPECT(m1.sort() == m2.sort());
};
test(1);
test(8);
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_slice_trans_diff_perm)
TEST_CASE_REGISTER(reorder_slice_trans<1>);
TEST_CASE_REGISTER(reorder_slice_trans<8>);
template <std::size_t BS>
void reorder_slice_trans_diff_perm()
{
auto create_m1 = [](std::size_t batch_size) {
migraphx::module m1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
migraphx::module m1;
{
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
std::vector<int64_t> perm0 = {0, 2, 1};
std::vector<int64_t> perm1 = {0, 1, 2};
auto input = m1.add_parameter("input", s);
......@@ -2146,21 +2276,16 @@ TEST_CASE(reorder_slice_trans_diff_perm)
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
m1.add_return({ret});
return m1;
};
auto test = [&](std::size_t batch_size) {
auto m1 = create_m1(batch_size);
run_pass(m1);
auto m2 = m1;
EXPECT(m1.sort() == m2.sort());
};
test(1);
test(4);
run_pass(m1);
auto m2 = m1;
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>);
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>);
TEST_CASE(reorder_slice_ins_deps)
{
auto create_module = [] {
......
......@@ -48,6 +48,26 @@ inline std::vector<std::vector<std::size_t>> to_lens(const std::vector<migraphx:
return result;
}
migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
const std::vector<size_t>& mbcast_lens,
const int axis)
{
migraphx::module m;
auto s = migraphx::shape{migraphx::shape::float_type, in_lens};
auto x = m.add_parameter("x", s);
auto y = m.add_parameter("y", s);
auto z = m.add_parameter("z", s);
auto xm =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mbcast_lens}}), x);
auto ym =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mbcast_lens}}), y);
auto zm =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mbcast_lens}}), z);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", axis}}), xm, ym, zm);
m.add_return({concat});
return m;
}
TEST_CASE(double_contig)
{
migraphx::program p;
......@@ -337,6 +357,87 @@ TEST_CASE(nop_convert)
EXPECT(std::distance(m.begin(), m.end()) == n - 1);
}
TEST_CASE(concat_multibroadcasts1)
{
// Broadcasted batch dim, new axis < old axis
std::vector<std::size_t> in_lens = {3, 4};
std::vector<std::size_t> mbcast_lens = {2, 3, 4};
const int axis = 2;
auto m = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(m.begin(), m.end()) == n - 2);
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
auto cd = std::distance(m.begin(), new_concat);
auto new_mb =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
TEST_CASE(concat_multibroadcasts2)
{
// Broadcasted middle dim, new axis == old axis
std::vector<std::size_t> in_lens = {3, 1, 4};
std::vector<std::size_t> mbcast_lens = {3, 2, 4};
const int axis = 0;
auto m = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(m.begin(), m.end()) == n - 2);
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
auto cd = std::distance(m.begin(), new_concat);
auto new_mb =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 0);
}
TEST_CASE(concat_multibroadcasts3)
{
// Broadcasted middle dim, new axis == old axis
std::vector<std::size_t> in_lens = {3, 1, 4};
std::vector<std::size_t> mbcast_lens = {3, 2, 4};
const int axis = 2;
auto m = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(m.begin(), m.end()) == n - 2);
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
auto cd = std::distance(m.begin(), new_concat);
auto new_mb =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 2);
}
TEST_CASE(concat_multibroadcasts4)
{
// Broadcasted batch dim, axis is broadcasted dim
std::vector<std::size_t> in_lens = {3, 4};
std::vector<std::size_t> mbcast_lens = {2, 3, 4};
const int axis = 0;
auto m = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
auto m1 = m;
run_pass(m);
EXPECT(m1 == m);
}
TEST_CASE(concat_transpose1)
{
migraphx::module m;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment