Unverified commit 15acaee9 authored by Umang Yadav, committed by GitHub

Preserve layout of fused kernel for `layernorm+pointwise` (#2185)

parent 74ba9649
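
The prefused layernorm ops previously computed their output shape from the first input alone, so a fused `layernorm+pointwise` kernel always produced a standard-layout output even when its inputs were transposed. With this change, `compute_shape` derives the output layout with `find_permutation` over the fused inputs, so a non-standard layout survives the fusion. Below is a rough standalone sketch of that idea; it is not part of the patch and only uses the `shape`/`permutation` helpers the patch itself relies on (`shape::from_permutation`, `find_permutation`):

```cpp
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <iostream>
#include <vector>

int main()
{
    using migraphx::shape;
    // A non-standard (transposed) layout, as constructed in the new verify test:
    // lens {8, 1, 16} laid out according to permutation {1, 2, 0}.
    auto s = shape::from_permutation(shape::float_type, {8, 1, 16}, {1, 2, 0});
    // find_permutation picks the dominant layout across a set of shapes...
    auto perm = migraphx::find_permutation(std::vector<shape>{s, s});
    // ...and from_permutation rebuilds an output shape in that layout, which is
    // what the new compute_shape returns instead of a standard-layout shape.
    auto out = shape::from_permutation(shape::float_type, s.lens(), perm);
    std::cout << out << "\n"; // same (non-standard) strides as s
}
```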
...
@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#include <migraphx/permutation.hpp>
 #include <migraphx/gpu/prefuse_ops.hpp>
 #include <migraphx/match/layernorm.hpp>
 #include <migraphx/check_shapes.hpp>
...
@@ -45,40 +46,42 @@ struct layernorm_base
     }
     shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
     {
-        std::size_t nargs = 1;
+        std::size_t nargs = N;
         if(not mods.empty())
         {
             auto* pm = mods.front();
-            nargs = pm->get_parameter_names().size();
+            nargs += pm->get_parameter_names().size() - 1;
         }
-        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
-        auto s = inputs.at(0);
+        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs);
+        auto s = inputs.front();
         auto t = s.type();
         if(not mods.empty())
             t = mods.front()->get_output_shapes().front().type();
-        if(s.scalar())
-        {
-            return s;
-        }
-        else if(s.broadcasted())
-        {
-            return {t, s.lens()};
-        }
-        else
-        {
-            return s.with_lens(t, s.lens());
-        }
+        // Scalar output if all inputs are scalar
+        if(inputs.front().elements() == 1 and
+           all_of(inputs, [](const auto& ss) { return ss.scalar(); }))
+            return inputs.front();
+        auto l_s = shape::from_permutation(
+            t, s.lens(), find_permutation(std::vector<shape>(inputs.begin(), inputs.begin() + N)));
+        // just prelayernorm or preadd_layernorm
+        if(nargs <= N)
+            return l_s;
+        // else, layernorm + pointwise fusion, preserve layout of fused op
+        std::vector<shape> lp_s(inputs.begin() + N, inputs.end());
+        lp_s.insert(lp_s.begin(), l_s);
+        return shape::from_permutation(t, s.lens(), find_permutation(lp_s));
     }
 };
-struct layernorm : layernorm_base<layernorm, 0>
+struct layernorm : layernorm_base<layernorm, 1>
 {
     std::string name() const { return "gpu::prelayernorm"; }
 };
 MIGRAPHX_REGISTER_OP(layernorm);
-struct add_layernorm : layernorm_base<add_layernorm, 1>
+struct add_layernorm : layernorm_base<add_layernorm, 2>
 {
     std::string name() const { return "gpu::preadd_layernorm"; }
 };
...
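A note on the input-count bookkeeping in the hunk above: `N` now counts the op's own inputs (1 for `gpu::prelayernorm`, 2 for `gpu::preadd_layernorm`) rather than the number of extra inputs, and a fused pointwise module adds `get_parameter_names().size() - 1` more, the `- 1` apparently accounting for the pointwise parameter that is fed by the layernorm result itself. For example, `preadd_layernorm` (`N = 2`) fused with a three-parameter pointwise module yields `nargs = 2 + 3 - 1 = 4`, which is the input count `check_shapes` then enforces.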
@@ -49,7 +49,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
     auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
     auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
     auto epsilon_mbcast = m.add_instruction(
-        migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
+        migraphx::make_op("multibroadcast", {{"out_lens", {dims.at(0), dims.at(1), 1}}}),
+        epsilon);
     auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
     auto sqrt = m.add_instruction(migraphx::make_op("sqrt"), add_epsilon);
     auto sqrt_mbcast =
...
@@ -57,7 +58,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
     auto div = m.add_instruction(migraphx::make_op("div"), sub, sqrt_mbcast);
     auto scale_mbcast =
         m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), scale);
-    auto mul = m.add_instruction(migraphx::make_op("mul"), scale_mbcast, div);
+    auto mul = m.add_instruction(migraphx::make_op("mul"), div, scale_mbcast);
     auto bias_mbcast =
         m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias);
     return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast);
...
@@ -161,3 +163,21 @@ struct test_layernorm_triadd_large : verify_program<test_layernorm_triadd_large>
         return p;
     }
 };
+
+struct test_add_layernorm_add_gemm_nonstd : verify_program<test_add_layernorm_add_gemm_nonstd>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        auto s =
+            migraphx::shape::from_permutation(migraphx::shape::float_type, {8, 1, 16}, {1, 2, 0});
+        auto x = mm->add_parameter("x", s);
+        auto y = mm->add_parameter("y", s);
+        auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {8, 16, 64}});
+        auto add = mm->add_instruction(migraphx::make_op("add"), x, y);
+        auto layernorm_ins = add_layernorm(*mm, add, s.lens());
+        mm->add_instruction(migraphx::make_op("dot"), layernorm_ins, z);
+        return p;
+    }
+};
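
The new `test_add_layernorm_add_gemm_nonstd` case exercises the preserved-layout path end to end: both inputs of the fused `add` + layernorm use the transposed `{1, 2, 0}` layout, and the result feeds a `dot`, so the verify harness compares reference and GPU outputs with the non-standard layout flowing through the fusion. A minimal sketch showing that this input shape really is non-standard, assuming only the public `shape` API already used in the test:

```cpp
#include <migraphx/shape.hpp>
#include <iostream>

int main()
{
    auto s = migraphx::shape::from_permutation(
        migraphx::shape::float_type, {8, 1, 16}, {1, 2, 0});
    // Same lens as a standard {8, 1, 16} shape, but with permuted strides,
    // so shape::standard() reports false.
    std::cout << s << "\n";
    std::cout << std::boolalpha << s.standard() << "\n"; // false
}
```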