Unverified commit a85b183b, authored by Chris Austen, committed by GitHub
Browse files

Final performance improvements for release (#1369)

* Improvements to handling and add constant passed to dot operator (#1280)
* Improve horizontal fusion of contiguous (#1292)
* Add pass to rewrite gelu as fast gelu (#1299)
* Add jit layernorm fusion (#1301)
parent 9a1ada1a
...@@ -186,9 +186,10 @@ struct nop ...@@ -186,9 +186,10 @@ struct nop
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; } migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
}; };
inline migraphx::literal get_2x2() inline migraphx::literal get_2x2(int base = 0)
{ {
return migraphx::literal{{migraphx::shape::float_type, {2, 2}}, {1, 2, 3, 4}}; return migraphx::literal{{migraphx::shape::float_type, {2, 2}},
{base + 1, base + 2, base + 3, base + 4}};
} }
inline migraphx::literal get_2x2_transposed() inline migraphx::literal get_2x2_transposed()
......
...@@ -108,15 +108,7 @@ struct function ...@@ -108,15 +108,7 @@ struct function
}; };
template <class Stream, class Iterator> template <class Stream, class Iterator>
inline Stream& stream_range(Stream& s, Iterator start, Iterator last) Stream& stream_range(Stream& s, Iterator start, Iterator last);
{
if(start != last)
{
s << *start;
std::for_each(std::next(start), last, [&](auto&& x) { s << ", " << x; });
}
return s;
}
template <class Stream> template <class Stream>
inline Stream& operator<<(Stream& s, std::nullptr_t) inline Stream& operator<<(Stream& s, std::nullptr_t)
...@@ -136,6 +128,17 @@ inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v. ...@@ -136,6 +128,17 @@ inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.
return s; return s;
} }
// Writes every element of [start, last) to `s`, placing ", " between
// consecutive elements. An empty range writes nothing. Returns `s` so the
// call can be chained.
template <class Stream, class Iterator>
inline Stream& stream_range(Stream& s, Iterator start, Iterator last)
{
    // Nothing to print for an empty range.
    if(start == last)
        return s;

    // The first element carries no leading separator...
    s << *start;
    // ...every following element is prefixed with one.
    for(auto it = std::next(start); it != last; ++it)
        s << ", " << *it;
    return s;
}
template <class T> template <class T>
const T& get_value(const T& x) const T& get_value(const T& x)
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/verify.hpp>
// GELU written with erf on top of an `add` (a + b):
//   v * (erf(v / 1.4140625) + 1) * 0.5        (1.4140625 ~= sqrt(2))
// After rewrite_gelu + dead_code_elimination the module must equal the
// sigmoid-style form:
//   v / (exp(-(v * 1.702)) + 1)
TEST_CASE(bias_gelu)
{
    migraphx::shape tensor_shape{migraphx::shape::half_type, {2, 4, 8}};
    migraphx::shape scalar_shape{migraphx::shape::half_type};

    // Module handed to the pass: the erf-based formulation.
    migraphx::module actual;
    {
        auto lhs       = actual.add_parameter("a", tensor_shape);
        auto rhs       = actual.add_parameter("b", tensor_shape);
        auto biased    = actual.add_instruction(migraphx::make_op("add"), lhs, rhs);
        auto sqrt2_lit = actual.add_literal(migraphx::literal{scalar_shape, {1.4140625f}});
        auto scaled    = add_common_op(actual, migraphx::make_op("div"), {biased, sqrt2_lit});
        auto erf_out   = actual.add_instruction(migraphx::make_op("erf"), scaled);
        auto one_lit   = actual.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto erf_plus1 = add_common_op(actual, migraphx::make_op("add"), {erf_out, one_lit});
        auto product   = actual.add_instruction(migraphx::make_op("mul"), biased, erf_plus1);
        auto half_lit  = actual.add_literal(migraphx::literal{scalar_shape, {0.5f}});
        auto gelu      = add_common_op(actual, migraphx::make_op("mul"), {product, half_lit});
        actual.add_return({gelu});
    }

    // Rewrite, then drop whatever the rewrite left unused.
    migraphx::rewrite_gelu{}.apply(actual);
    migraphx::dead_code_elimination{}.apply(actual);

    // Reference module: the form the pass is expected to emit.
    migraphx::module expected;
    {
        auto lhs       = expected.add_parameter("a", tensor_shape);
        auto rhs       = expected.add_parameter("b", tensor_shape);
        auto biased    = expected.add_instruction(migraphx::make_op("add"), lhs, rhs);
        auto coeff_lit = expected.add_literal(migraphx::literal{scalar_shape, {1.702f}});
        auto scaled    = add_common_op(expected, migraphx::make_op("mul"), {biased, coeff_lit});
        auto negated   = expected.add_instruction(migraphx::make_op("neg"), scaled);
        auto exp_out   = expected.add_instruction(migraphx::make_op("exp"), negated);
        auto one_lit   = expected.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto denom     = add_common_op(expected, migraphx::make_op("add"), {exp_out, one_lit});
        auto result    = expected.add_instruction(migraphx::make_op("div"), biased, denom);
        expected.add_return({result});
    }
    EXPECT(actual == expected);
}
// Same erf-based GELU pattern as the bias case, but the GELU input is produced
// by a `sub` (a - b) instead of an `add`. The rewritten module must still match
// the v / (exp(-(v * 1.702)) + 1) form.
TEST_CASE(non_bias_gelu)
{
    migraphx::shape tensor_shape{migraphx::shape::half_type, {2, 4, 8}};
    migraphx::shape scalar_shape{migraphx::shape::half_type};

    // Module handed to the pass: the erf-based formulation.
    migraphx::module actual;
    {
        auto lhs       = actual.add_parameter("a", tensor_shape);
        auto rhs       = actual.add_parameter("b", tensor_shape);
        auto diff      = actual.add_instruction(migraphx::make_op("sub"), lhs, rhs);
        auto sqrt2_lit = actual.add_literal(migraphx::literal{scalar_shape, {1.4140625f}});
        auto scaled    = add_common_op(actual, migraphx::make_op("div"), {diff, sqrt2_lit});
        auto erf_out   = actual.add_instruction(migraphx::make_op("erf"), scaled);
        auto one_lit   = actual.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto erf_plus1 = add_common_op(actual, migraphx::make_op("add"), {erf_out, one_lit});
        auto product   = actual.add_instruction(migraphx::make_op("mul"), diff, erf_plus1);
        auto half_lit  = actual.add_literal(migraphx::literal{scalar_shape, {0.5f}});
        auto gelu      = add_common_op(actual, migraphx::make_op("mul"), {product, half_lit});
        actual.add_return({gelu});
    }

    // Rewrite, then drop whatever the rewrite left unused.
    migraphx::rewrite_gelu{}.apply(actual);
    migraphx::dead_code_elimination{}.apply(actual);

    // Reference module: the form the pass is expected to emit.
    migraphx::module expected;
    {
        auto lhs       = expected.add_parameter("a", tensor_shape);
        auto rhs       = expected.add_parameter("b", tensor_shape);
        auto diff      = expected.add_instruction(migraphx::make_op("sub"), lhs, rhs);
        auto coeff_lit = expected.add_literal(migraphx::literal{scalar_shape, {1.702f}});
        auto scaled    = add_common_op(expected, migraphx::make_op("mul"), {diff, coeff_lit});
        auto negated   = expected.add_instruction(migraphx::make_op("neg"), scaled);
        auto exp_out   = expected.add_instruction(migraphx::make_op("exp"), negated);
        auto one_lit   = expected.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto denom     = add_common_op(expected, migraphx::make_op("add"), {exp_out, one_lit});
        auto result    = expected.add_instruction(migraphx::make_op("div"), diff, denom);
        expected.add_return({result});
    }
    EXPECT(actual == expected);
}
// Test binary entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -358,7 +358,33 @@ TEST_CASE(simplify_mul_add) ...@@ -358,7 +358,33 @@ TEST_CASE(simplify_mul_add)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_inner_broadcast) TEST_CASE(simplify_dot_add)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto one = m1.add_literal(get_2x2());
auto two = m1.add_literal(get_2x2(1));
auto sum = m1.add_instruction(migraphx::make_op("add"), one, x);
auto dot = m1.add_instruction(migraphx::make_op("dot"), sum, two);
m1.add_instruction(pass_op{}, dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto one = m2.add_literal(get_2x2());
auto two = m2.add_literal(get_2x2(1));
auto dot1 = m2.add_instruction(migraphx::make_op("dot"), x, two);
auto dot2 = m2.add_instruction(migraphx::make_op("dot"), one, two);
auto sum = m2.add_instruction(migraphx::make_op("add"), dot1, dot2);
m2.add_instruction(pass_op{}, sum);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast1)
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
migraphx::module m1; migraphx::module m1;
...@@ -383,6 +409,31 @@ TEST_CASE(simplify_inner_broadcast) ...@@ -383,6 +409,31 @@ TEST_CASE(simplify_inner_broadcast)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
// Two {1, 1, 1, 1} inputs are each multibroadcast to {2, 1, 4, 5} and then
// added. After run_pass the module must equal one that adds the un-broadcast
// inputs first and applies a single multibroadcast to the sum.
TEST_CASE(simplify_inner_broadcast2)
{
    auto bcast = migraphx::op::multibroadcast{{2, 1, 4, 5}};

    // Input module: broadcast each operand, then add.
    migraphx::module actual;
    {
        auto lhs    = actual.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
        auto rhs    = actual.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 1}});
        auto lhs_bc = actual.add_instruction(bcast, lhs);
        auto rhs_bc = actual.add_instruction(bcast, rhs);
        auto total  = actual.add_instruction(migraphx::make_op("add"), lhs_bc, rhs_bc);
        actual.add_instruction(pass_op{}, total);
    }
    run_pass(actual);

    // Expected module: add first, broadcast once afterwards.
    migraphx::module expected;
    {
        auto lhs      = expected.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
        auto rhs      = expected.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 1}});
        auto total    = expected.add_instruction(migraphx::make_op("add"), lhs, rhs);
        auto total_bc = expected.add_instruction(bcast, total);
        expected.add_instruction(pass_op{}, total_bc);
    }
    EXPECT(actual == expected);
}
TEST_CASE(simplify_add_conv1) TEST_CASE(simplify_add_conv1)
{ {
migraphx::module m; migraphx::module m;
......
...@@ -39,6 +39,15 @@ void run_pass(migraphx::module& m) ...@@ -39,6 +39,15 @@ void run_pass(migraphx::module& m)
migraphx::run_passes(m, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}}); migraphx::run_passes(m, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
} }
// Returns the lens() of each shape in `shapes`, in the same order as the input.
inline std::vector<std::vector<std::size_t>> to_lens(const std::vector<migraphx::shape>& shapes)
{
    std::vector<std::vector<std::size_t>> lens_per_shape;
    for(const auto& shp : shapes)
        lens_per_shape.push_back(shp.lens());
    return lens_per_shape;
}
TEST_CASE(double_contig) TEST_CASE(double_contig)
{ {
migraphx::program p; migraphx::program p;
...@@ -1275,4 +1284,82 @@ TEST_CASE(transpose_slice_single_transpose) ...@@ -1275,4 +1284,82 @@ TEST_CASE(transpose_slice_single_transpose)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
// A {2, 384, 36, 64} input is transposed with permutation {0, 2, 1, 3} and then
// sliced on axis 1 over [0, 12). After run_pass the output shapes must be
// unchanged and the module must equal the unsqueeze -> transpose -> slice ->
// squeeze chain built below.
TEST_CASE(transpose_slice_non_packed_axis)
{
    migraphx::module actual;
    {
        auto input = actual.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
        auto permuted = actual.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), input);
        auto head = actual.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}),
            permuted);
        auto root = actual.add_instruction(migraphx::make_op("sqrt"), head);
        actual.add_return({root});
    }

    // Recorded before the pass so the check below compares against the
    // untransformed module's outputs.
    const auto shapes_before = actual.get_output_shapes();
    run_pass(actual);
    EXPECT(actual.get_output_shapes() == shapes_before);

    migraphx::module expected;
    {
        auto input = expected.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
        auto split = expected.add_instruction(
            migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), input);
        auto permuted = expected.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), split);
        auto head = expected.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}),
            permuted);
        auto flat = expected.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), head);
        auto root = expected.add_instruction(migraphx::make_op("sqrt"), flat);
        expected.add_return({root});
    }
    EXPECT(actual == expected);
}
// Three slices of axis 1 ([0, 12), [12, 24), [24, 36)) are taken from one
// transposed {2, 384, 36, 64} input; the middle slice is transposed again.
// After run_pass the output lens must be unchanged, and the sorted module must
// equal the sorted unsqueeze/transpose/slice/squeeze module built below.
TEST_CASE(transpose_slice_non_packed_multi_axis)
{
    migraphx::module actual;
    {
        auto input = actual.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
        auto permuted = actual.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), input);
        auto part_a = actual.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}),
            permuted);
        auto part_b = actual.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}),
            permuted);
        auto part_b_t = actual.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), part_b);
        auto part_c = actual.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}),
            permuted);
        actual.add_return({part_a, part_b_t, part_c});
    }

    // Only the lens are compared here (via to_lens), not the full shapes.
    const auto shapes_before = actual.get_output_shapes();
    run_pass(actual);
    EXPECT(to_lens(actual.get_output_shapes()) == to_lens(shapes_before));

    migraphx::module expected;
    {
        auto input = expected.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
        auto split = expected.add_instruction(
            migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), input);
        auto permuted = expected.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), split);
        auto part_a = expected.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}),
            permuted);
        auto flat_a =
            expected.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), part_a);
        auto part_b = expected.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}),
            permuted);
        auto flat_b =
            expected.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), part_b);
        auto flat_b_t = expected.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), flat_b);
        auto part_c = expected.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}),
            permuted);
        auto flat_c =
            expected.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), part_c);
        expected.add_return({flat_a, flat_b_t, flat_c});
    }
    EXPECT(actual.sort() == expected.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
// Verification program: dot({1, 2, 3}, {1, 3, 4}) plus a {1, 1, 4} parameter
// that is multibroadcast to {1, 2, 4} before the add.
struct gemm_add_broadcast1 : verify_program<gemm_add_broadcast1>
{
    // Builds the program; the final `add` is the last instruction of the main module.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        const migraphx::shape lhs_shape{migraphx::shape::float_type, {1, 2, 3}};
        const migraphx::shape rhs_shape{migraphx::shape::float_type, {1, 3, 4}};
        const migraphx::shape bias_shape{migraphx::shape::float_type, {1, 1, 4}};

        auto lhs  = main_mod->add_parameter("1", lhs_shape);
        auto rhs  = main_mod->add_parameter("2", rhs_shape);
        auto bias = main_mod->add_parameter("3", bias_shape);

        // The broadcast is inserted ahead of the dot, as in the original ordering.
        auto bias_bc = main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 4}}}), bias);
        auto gemm = main_mod->add_instruction(migraphx::make_op("dot"), lhs, rhs);
        main_mod->add_instruction(migraphx::make_op("add"), gemm, bias_bc);
        return prog;
    }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
// Verification program: dot({1, 2, 3}, {1, 3, 4}) plus a {1, 2, 1} parameter
// that is multibroadcast to {1, 2, 4} before the add.
struct gemm_add_broadcast2 : verify_program<gemm_add_broadcast2>
{
    // Builds the program; the final `add` is the last instruction of the main module.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        const migraphx::shape lhs_shape{migraphx::shape::float_type, {1, 2, 3}};
        const migraphx::shape rhs_shape{migraphx::shape::float_type, {1, 3, 4}};
        const migraphx::shape bias_shape{migraphx::shape::float_type, {1, 2, 1}};

        auto lhs  = main_mod->add_parameter("1", lhs_shape);
        auto rhs  = main_mod->add_parameter("2", rhs_shape);
        auto bias = main_mod->add_parameter("3", bias_shape);

        // The broadcast is inserted ahead of the dot, as in the original ordering.
        auto bias_bc = main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 4}}}), bias);
        auto gemm = main_mod->add_instruction(migraphx::make_op("dot"), lhs, rhs);
        main_mod->add_instruction(migraphx::make_op("add"), gemm, bias_bc);
        return prog;
    }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verification program in half precision: with v = x + y it builds
//   (v * 0.5) * (erf(v / M_SQRT2) + 1)
// where each scalar literal is multibroadcast to the input lens first.
struct test_add_gelu_half : verify_program<test_add_gelu_half>
{
    // Builds the program; the final `mul` is the last instruction of the main module.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();
        std::vector<size_t> input_lens{1, 1, 5};

        auto lhs = main_mod->add_parameter("x", {migraphx::shape::half_type, input_lens});
        auto rhs = main_mod->add_parameter("y", {migraphx::shape::half_type, input_lens});

        // Scalar constants of the formula.
        // NOTE(review): M_SQRT2 is a <cmath>/<math.h> macro; no such include is
        // visible here, so it is presumably pulled in transitively - confirm.
        auto half_lit =
            main_mod->add_literal(migraphx::literal{{migraphx::shape::half_type}, {0.5f}});
        auto one_lit =
            main_mod->add_literal(migraphx::literal{{migraphx::shape::half_type}, {1.0f}});
        auto sqrt2_lit =
            main_mod->add_literal(migraphx::literal{{migraphx::shape::half_type}, {M_SQRT2}});

        auto biased = main_mod->add_instruction(migraphx::make_op("add"), lhs, rhs);

        // v * 0.5
        auto half_bc = main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), half_lit);
        auto halved = main_mod->add_instruction(migraphx::make_op("mul"), biased, half_bc);

        // erf(v / sqrt2)
        auto sqrt2_bc = main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), sqrt2_lit);
        auto scaled  = main_mod->add_instruction(migraphx::make_op("div"), biased, sqrt2_bc);
        auto erf_out = main_mod->add_instruction(migraphx::make_op("erf"), scaled);

        // erf(...) + 1
        auto one_bc = main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), one_lit);
        auto erf_plus1 = main_mod->add_instruction(migraphx::make_op("add"), erf_out, one_bc);

        main_mod->add_instruction(migraphx::make_op("mul"), halved, erf_plus1);
        return prog;
    }
};
...@@ -68,7 +68,7 @@ struct test_layernorm : verify_program<test_layernorm> ...@@ -68,7 +68,7 @@ struct test_layernorm : verify_program<test_layernorm>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 1, 5}; std::vector<size_t> dims = {1, 2, 5};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
add_layernorm(*mm, x, dims); add_layernorm(*mm, x, dims);
return p; return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment