Unverified Commit 5e132673 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix make_inner_storage function (#1607)

parent f1c6647d
......@@ -410,9 +410,9 @@ struct block_large
};
template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {f};
return {{}, {f}};
}
template <class Op, class T, class Read, class N, class... Ts>
......@@ -483,9 +483,9 @@ struct lane
};
template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {f};
return {{}, {f}};
}
template <class Op, class T, class Read, class N, class U, class... Us>
......
......@@ -28,8 +28,6 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/reduce_mean.hpp>
migraphx::instruction_ref add_layernorm(migraphx::module& m,
migraphx::instruction_ref x,
std::vector<size_t> dims,
......@@ -42,14 +40,14 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}});
auto exponent = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {2.0f}});
auto mean = m.add_instruction(migraphx::op::reduce_mean({2}), x);
auto mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x);
auto mean_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto sub = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast);
auto exponent_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = m.add_instruction(migraphx::op::reduce_mean({2}), pow);
auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
auto epsilon_mbcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
......@@ -91,6 +89,19 @@ struct test_layernorm2 : verify_program<test_layernorm2>
}
};
struct test_layernorm_large : verify_program<test_layernorm_large>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 32, 262144};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
add_layernorm(*mm, x, dims);
return p;
}
};
struct test_layernorm_fp16 : verify_program<test_layernorm_fp16>
{
migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment