Unverified Commit 5e132673 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix make_inner_storage function (#1607)

parent f1c6647d
...@@ -410,9 +410,9 @@ struct block_large ...@@ -410,9 +410,9 @@ struct block_large
}; };
template <class Size, class F> template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f) static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{ {
return {f}; return {{}, {f}};
} }
template <class Op, class T, class Read, class N, class... Ts> template <class Op, class T, class Read, class N, class... Ts>
...@@ -483,9 +483,9 @@ struct lane ...@@ -483,9 +483,9 @@ struct lane
}; };
template <class Size, class F> template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f) static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{ {
return {f}; return {{}, {f}};
} }
template <class Op, class T, class Read, class N, class U, class... Us> template <class Op, class T, class Read, class N, class U, class... Us>
......
...@@ -28,8 +28,6 @@ ...@@ -28,8 +28,6 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/reduce_mean.hpp>
migraphx::instruction_ref add_layernorm(migraphx::module& m, migraphx::instruction_ref add_layernorm(migraphx::module& m,
migraphx::instruction_ref x, migraphx::instruction_ref x,
std::vector<size_t> dims, std::vector<size_t> dims,
...@@ -42,14 +40,14 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m, ...@@ -42,14 +40,14 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}}); auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}});
auto exponent = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {2.0f}}); auto exponent = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {2.0f}});
auto mean = m.add_instruction(migraphx::op::reduce_mean({2}), x); auto mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x);
auto mean_mbcast = auto mean_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto sub = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast); auto sub = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast);
auto exponent_mbcast = auto exponent_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast); auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = m.add_instruction(migraphx::op::reduce_mean({2}), pow); auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
auto epsilon_mbcast = m.add_instruction( auto epsilon_mbcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon); migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast); auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
...@@ -91,6 +89,19 @@ struct test_layernorm2 : verify_program<test_layernorm2> ...@@ -91,6 +89,19 @@ struct test_layernorm2 : verify_program<test_layernorm2>
} }
}; };
struct test_layernorm_large : verify_program<test_layernorm_large>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 32, 262144};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
add_layernorm(*mm, x, dims);
return p;
}
};
struct test_layernorm_fp16 : verify_program<test_layernorm_fp16> struct test_layernorm_fp16 : verify_program<test_layernorm_fp16>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment