bias.hpp 2.54 KB
Newer Older
carlushuang's avatar
carlushuang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <limits>
#include <ostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"

// Kind of bias selected on the host side.
//
// The numeric values are spelled out explicitly because they must be kept in sync
// with BlockAttentionBiasEnum (presumably declared in the ck_tile fmha headers
// included above -- confirm before renumbering).
enum class bias_enum
{
    // no bias at all
    no_bias = 0,
    // element-wise bias
    elementwise_bias = 1,
    // alibi
    alibi = 2,
};

// Describes the selected bias: its kind (`type`) plus a layout selector (`rank_info`).
//
// Text forms handled by this struct:
//   - serialize() / operator<< write "n", "e", "e[<rank>]", "alibi" or "alibi[<rank>]".
//   - decode() reads "<kind>" or "<kind>:<rank>".
// NOTE: decode() only looks for ':' to find the rank, so a serialized string that
// carries a non-zero rank ("e[1]") does not decode back to the same rank.
struct bias_info
{
    // Default member initializers keep a default-constructed bias_info well defined
    // (it means "no bias"); brace initialization `bias_info{type, rank}` still works.
    bias_enum type = bias_enum::no_bias;
    /*
     * simple dispatch logic
     *
     * if type == elementwise_bias:
     *      if rank_info == 0:
     *           bias is 1*1*s*s
     *      elif rank_info == 1:
     *           bias is 1*h*s*s
     *      elif rank_info == 2:
     *           bias is b*h*s*s
     *
     * elif type == alibi:
     *       if rank_info == 0:
     *           alibi in 1*h
     *       elif rank_info == 1:
     *           alibi in b*h
     */
    int rank_info = 0;

    // Writes the short textual name of this bias to `os`.
    //
    //   no_bias          -> "n"
    //   elementwise_bias -> "e"      (followed by "[<rank_info>]" when rank_info != 0)
    //   alibi            -> "alibi"  (followed by "[<rank_info>]" when rank_info != 0)
    //
    // A `type` holding any other value writes nothing.
    void serialize(std::ostream& os) const
    {
        switch(type)
        {
        case bias_enum::no_bias:
            os << "n";
            return; // no_bias never carries a rank suffix
        case bias_enum::elementwise_bias: os << "e"; break;
        case bias_enum::alibi: os << "alibi"; break;
        default: return;
        }

        if(rank_info != 0)
        {
            os << "[" << rank_info << "]";
        }
    }

    // Parses a bias description of the form "<kind>" or "<kind>:<rank>".
    //
    //   "0" / "n" (exact match)            -> no_bias
    //   anything starting with '1' or 'e'  -> elementwise_bias  (so "elementwise" works)
    //   anything starting with '2' or 'a'  -> alibi             (so "alibi" works)
    //   everything else, including ""      -> no_bias, rank_info 0
    //
    // For elementwise_bias and alibi, the text after the first ':' (if any) is read as
    // a base-10 integer and stored in rank_info; text that is not a number yields 0 and
    // values outside the range of int saturate. The rank is not checked against the
    // values listed in the rank_info comment.
    static bias_info decode(std::string str)
    {
        bias_info info{bias_enum::no_bias, 0};
        if(str == "0" || str == "n")
        {
            return info;
        }

        // Only the first character decides the kind; '\0' stands for "empty string".
        const char head = str.empty() ? '\0' : str.front();
        if(head == '1' || head == 'e')
        {
            info.type      = bias_enum::elementwise_bias;
            info.rank_info = parse_rank(str);
        }
        else if(head == '2' || head == 'a')
        {
            info.type      = bias_enum::alibi;
            info.rank_info = parse_rank(str);
        }
        return info;
    }

    // Streams `bi` using serialize().
    friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
    {
        bi.serialize(os);
        return os;
    }

    private:
    // Returns the integer that follows the first ':' in `str`, or 0 when there is no ':'.
    // std::strtol is used instead of atoi because atoi has undefined behaviour when the
    // number does not fit in an int; here such input saturates to the int range.
    static int parse_rank(const std::string& str)
    {
        const auto colon = str.find(':');
        if(colon == std::string::npos)
        {
            return 0;
        }

        const long parsed = std::strtol(str.c_str() + colon + 1, nullptr, 10);
        const long lowest = std::numeric_limits<int>::min();
        const long highest = std::numeric_limits<int>::max();
        if(parsed < lowest)
        {
            return std::numeric_limits<int>::min();
        }
        if(parsed > highest)
        {
            return std::numeric_limits<int>::max();
        }
        return static_cast<int>(parsed);
    }
};