test_layernorm.cpp 2.25 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <cstdint>
#include <vector>

// Appends a layer-normalization subgraph to `p` and returns its final instruction.
//
// Parameters added to `p`:
//   "x"     : float, shape `dims` (the tensor being normalized)
//   "scale" : float, shape {dims.back()}
//   "bias"  : float, shape {dims.back()}
//
// The graph normalizes "x" over its last axis:
//   out = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + 1e-12) * scale + bias
// with scale and bias broadcast across all leading axes.
//
// Precondition: `dims` is non-empty (dims.back() is read unconditionally).
migraphx::instruction_ref add_layernorm(migraphx::program& p, std::vector<size_t> dims)
{
    auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
    auto scale =
        p.add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
    auto bias =
        p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
    auto epsilon  = p.add_literal(1e-12f);
    auto exponent = p.add_literal(2.0f);

    // Normalize over the last axis, whatever the rank of `dims`.
    const auto last_axis = static_cast<std::int64_t>(dims.size() - 1);
    // Shape of mean/variance after the reduction: `dims` with the reduced axis
    // kept at length 1. epsilon must be broadcast to exactly this shape so the
    // elementwise add below works for any batch size and rank.
    std::vector<size_t> reduced_dims = dims;
    reduced_dims.back()              = 1;

    // Mean over the normalized axis, broadcast back so it can be subtracted from x.
    auto mean        = p.add_instruction(migraphx::op::reduce_mean({last_axis}), x);
    auto mean_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, mean);
    auto sub         = p.add_instruction(migraphx::op::sub{}, x, mean_mbcast);

    // Variance: mean of the squared deviations.
    auto exponent_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, exponent);
    auto pow             = p.add_instruction(migraphx::op::pow{}, sub, exponent_mbcast);
    auto var             = p.add_instruction(migraphx::op::reduce_mean({last_axis}), pow);

    // Standard deviation; epsilon keeps the divisor non-zero when the variance is 0.
    auto epsilon_mbcast =
        p.add_instruction(migraphx::op::multibroadcast{reduced_dims}, epsilon);
    auto add_epsilon = p.add_instruction(migraphx::op::add{}, var, epsilon_mbcast);
    auto sqrt        = p.add_instruction(migraphx::op::sqrt{}, add_epsilon);
    auto sqrt_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, sqrt);
    auto div         = p.add_instruction(migraphx::op::div{}, sub, sqrt_mbcast);

    // Affine transform with the per-feature scale and bias.
    auto scale_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, scale);
    auto mul          = p.add_instruction(migraphx::op::mul{}, scale_mbcast, div);
    auto bias_mbcast  = p.add_instruction(migraphx::op::multibroadcast{dims}, bias);
    return p.add_instruction(migraphx::op::add{}, mul, bias_mbcast);
}

// Verification case: layer normalization of a single row of five values
// (input shape {1, 1, 5}).
struct test_layernorm : verify_program<test_layernorm>
{
    migraphx::program create_program() const
    {
        const std::vector<size_t> input_dims{1, 1, 5};
        migraphx::program prog;
        add_layernorm(prog, input_dims);
        return prog;
    }
};

// Verification case: layer normalization of four rows of 24 values each
// (input shape {1, 4, 24}).
struct test_layernorm2 : verify_program<test_layernorm2>
{
    migraphx::program create_program() const
    {
        const std::vector<size_t> input_dims{1, 4, 24};
        migraphx::program prog;
        add_layernorm(prog, input_dims);
        return prog;
    }
};