attn.py 9.94 KB
Newer Older
Casper Hansen's avatar
Casper Hansen committed
1
import math
Haotian Tang's avatar
Haotian Tang committed
2
3
4
import torch
import torch.nn as nn
import awq_inference_engine
Casper Hansen's avatar
Casper Hansen committed
5
from torch.nn import functional as F
Casper Hansen's avatar
Casper Hansen committed
6

Casper Hansen's avatar
Casper Hansen committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Tabulate the complex rotary factors ``exp(1j * pos * inv_freq)``.

    Args:
        dim: rotary feature width; one frequency is produced per feature pair,
            giving ``dim // 2`` columns.
        end: number of positions to tabulate (``0 .. end - 1``).
        theta: base of the geometric frequency progression.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)`` with unit magnitude,
        whose entry ``[p, k]`` has angle ``p / theta ** (2k / dim)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # Unit-magnitude polar form -> complex64 for float32 angles.
    return torch.polar(torch.ones_like(angles), angles)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Return a view of ``freqs_cis`` that broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The returned
    view has the same rank as ``x``, keeps those two sizes at axis 1 and at the
    last axis, and has size 1 on every other axis.

    Raises:
        AssertionError: if ``x`` has fewer than 2 dims or the shapes disagree.
    """
    rank = x.ndim
    assert rank > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    last_axis = rank - 1
    target = []
    for axis, size in enumerate(x.shape):
        target.append(size if axis in (1, last_axis) else 1)
    return freqs_cis.view(*target)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Rotate query/key features by a precomputed complex rotary table.

    The last axis (size ``D``) of each input is split into two halves: feature
    ``j`` of the first half and feature ``j`` of the second half form the real
    and imaginary parts of one complex number, which is multiplied by the
    matching entry of ``freqs_cis``.

    Args:
        xq, xk: tensors whose axis 1 indexes positions and whose last axis holds
            an even number of rotary features, e.g. ``(bsz, seqlen, n_heads,
            head_dim)`` as produced in ``QuantAttentionFused.forward``.
        freqs_cis: complex table of shape ``(xq.shape[1], xq.shape[-1] // 2)``,
            e.g. a row slice of ``precompute_freqs_cis`` output.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq`` and ``xk``.
    """

    def _to_complex(t):
        # (..., D) -> (..., 2, D/2) -> (..., D/2, 2): pairs feature j with j + D/2.
        return torch.view_as_complex(
            t.float().reshape(*t.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )

    def _from_complex(t):
        # Inverse of _to_complex. flatten(-2) instead of a fixed start axis so
        # inputs of any rank get their original shape back (flatten(3) only
        # worked for 4-D inputs).
        return torch.view_as_real(t).transpose(-2, -1).flatten(-2)

    xq_ = _to_complex(xq)
    xk_ = _to_complex(xk)
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = _from_complex(xq_ * freqs_cis)
    xk_out = _from_complex(xk_ * freqs_cis)
    return xq_out.type_as(xq), xk_out.type_as(xk)

Casper Hansen's avatar
Casper Hansen committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def gen_slopes(n_heads, alibi_bias_max=8):
    """Return per-head ALiBi slopes as a ``(1, n_heads, 1, 1)`` float32 tensor.

    With ``P`` = ``n_heads`` rounded up to a power of two, the base slopes are
    ``2 ** -(k * alibi_bias_max / P)`` for ``k = 1 .. P``. When ``n_heads`` is
    not itself a power of two, the odd-indexed base slopes are placed first,
    then the even-indexed ones, and the result is truncated to ``n_heads``.
    """
    padded = 2 ** math.ceil(math.log2(n_heads))
    exponents = torch.arange(1, padded + 1, dtype=torch.float32) * (alibi_bias_max / padded)
    slopes = 1.0 / torch.pow(2, exponents)
    if padded == n_heads:
        return slopes.view(1, n_heads, 1, 1)
    reordered = torch.concat([slopes[1::2], slopes[::2]])
    return reordered[:n_heads].view(1, n_heads, 1, 1)


def build_alibi_bias(
    n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32
):
    """Build ALiBi slopes and the additive attention bias.

    Args:
        n_heads: number of attention heads.
        seq_len: number of positions covered by the bias.
        full: if False, the bias depends only on the key index ``j`` and has
            shape ``(1, n_heads, 1, seq_len)`` with values
            ``(j - (seq_len - 1)) * slope``. If True, it has shape
            ``(1, n_heads, seq_len, seq_len)`` with values ``-|j - i| * slope``.
        alibi_bias_max: forwarded to ``gen_slopes``.
        dtype: dtype of both returned tensors.

    Returns:
        ``(slopes, alibi_bias)`` where ``slopes`` has shape ``(n_heads,)``.
    """
    key_pos = torch.arange(1 - seq_len, 1, dtype=torch.int32)
    bias = key_pos.view(1, 1, 1, seq_len)
    if full:
        query_pos = key_pos.view(1, 1, seq_len, 1)
        bias = -(bias - query_pos).abs()
    slopes = gen_slopes(n_heads, alibi_bias_max)
    bias = bias * slopes
    return slopes.view(n_heads).to(dtype=dtype), bias.to(dtype=dtype)

Haotian Tang's avatar
Haotian Tang committed
61
62
63
64
65
66
67
68

class QuantLlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding applied through ``awq_inference_engine``.

    Keeps an fp16, non-persistent buffer ``cos_sin_cache`` whose rows are
    positions; for even ``dim`` its shape is ``(seq_len, dim)``, with the
    first ``dim // 2`` columns holding ``cos(pos * inv_freq)`` and the last
    ``dim // 2`` holding ``sin(pos * inv_freq)``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """
        Args:
            dim: rotary dimension (also passed to the kernel in ``forward``).
            max_position_embeddings: number of positions the cache is built for.
            base: base of the inverse-frequency progression.
            device: device for ``inv_freq`` (and hence the cache).
        """
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``cos_sin_cache`` for positions ``0 .. seq_len - 1``.

        ``dtype`` is accepted for interface compatibility but is not used: the
        cache is always stored as fp16 (``.half()``).
        """
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Layout consumed by the kernel: [cos | sin] along the last axis.
        cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        positions: torch.Tensor,
    ):
        """Apply rotary embedding to ``query`` and ``key``.

        Both are made contiguous and handed to
        ``awq_inference_engine.rotary_embedding_neox`` together with
        ``positions`` and ``cos_sin_cache``. The kernel's return value is
        ignored and the same tensors are returned, so the kernel presumably
        rotates them in place. Expected shapes/dtypes are defined by the
        kernel — NOTE(review): confirm against its binding.

        Returns:
            ``(query, key)`` — the contiguous tensors passed to the kernel.
        """
        query = query.contiguous()
        key = key.contiguous()
        awq_inference_engine.rotary_embedding_neox(
            positions,
            query,
            key,
            self.dim,
            self.cos_sin_cache
        )
        return query, key

Casper Hansen's avatar
Casper Hansen committed
116
class QuantAttentionFused(nn.Module):
    """Attention over a packed QKV projection with a preallocated fp16 KV cache.

    Multi-token calls (``seqlen > 1``) compute attention in PyTorch and write
    the new keys/values into ``cache_k``/``cache_v``. Single-token calls
    delegate to ``awq_inference_engine.single_query_attention`` with those
    caches. Positions are encoded with ALiBi when ``use_alibi`` is set;
    otherwise rotary embedding is used (``apply_rotary_emb`` for multi-token
    calls, the kernel with ``is_neox=True`` for single-token calls).

    ``start_pos`` counts cached positions: it advances by ``seqlen`` after each
    call with ``use_cache=True`` and is reset to 0 after any other call.
    """

    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len, use_alibi=False, attention_shapes=None, rope_theta=10000):
        """
        Args:
            hidden_size: model width; ``head_dim = hidden_size // num_heads``.
            num_heads: number of attention heads.
            qkv_layer: module whose output is viewed as
                ``(bsz, seqlen) + attention_shapes["xqkv_view"]``.
            o_proj: output projection applied to the attention output.
            dev: device for the KV caches and position tables.
            max_seq_len: KV-cache capacity in positions.
            use_alibi: use ALiBi instead of rotary embedding.
            attention_shapes: optional dict overriding the default tensor
                layouts (same keys as the default below).
            rope_theta: rotary base, used both for the prefill table and the
                decode kernel. Defaults to 10000, the previously hard-coded value.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.n_local_heads = num_heads
        self.head_dim = self.hidden_size // num_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0
        self.use_alibi = use_alibi
        self.cache_batch_size = 1
        self.rope_theta = rope_theta
        self.attention_shapes = attention_shapes if attention_shapes is not None else {
            # following fastertransformer definition
            "cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
            # 8: pack 8 fp16 in FT, if fp32 then use 4
            "cache_k": (self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8,),
            "xqkv_view": (-1, self.n_local_heads, self.head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, 0],
            "xk_slice": lambda xqkv: xqkv[:, :, 1],
            "xv_slice": lambda xqkv: xqkv[:, :, 2],
            "xk_reshape": (self.n_local_heads, self.head_dim // 8, 8),
            "xk_view": (self.n_local_heads, self.head_dim),
            "xv_view": (self.n_local_heads, self.head_dim),
            "single_xq_view": (self.n_local_heads, self.head_dim),
            "single_xk_view": (self.n_local_heads, self.head_dim),
            "single_xv_view": (self.n_local_heads, self.head_dim)
        }

        # Allocate directly on the target device in fp16, instead of building
        # an fp32 CPU tensor and copying it twice (.to(dev) then .half()).
        self.cache_v = torch.zeros(
            self.attention_shapes["cache_v"], device=dev, dtype=torch.float16
        )
        self.cache_k = torch.zeros(
            self.attention_shapes["cache_k"], device=dev, dtype=torch.float16
        )

        if use_alibi:
            alibi_slopes, alibi_bias = build_alibi_bias(self.n_local_heads, max_seq_len)
            self.alibi_slopes = alibi_slopes.float().to(dev)
            self.alibi_bias = alibi_bias.float().to(dev)
            self.rotary_dim = 0
            self.is_neox = False
        else:
            # Rotary table for 2 * max_seq_len positions.
            self.freqs_cis = precompute_freqs_cis(
                hidden_size // num_heads,
                max_seq_len * 2,
                theta=rope_theta,
            ).to(dev)
            self.rotary_dim = self.head_dim
            self.alibi_slopes = None
            self.is_neox = True

    def forward(
        self,
        hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
    ):
        """Run attention for ``hidden_states`` of shape ``(bsz, seqlen, hidden)``.

        ``past_key_value``, ``position_ids`` and ``output_attentions`` are
        accepted for call compatibility but not read; positions come from
        ``self.start_pos``.

        Returns:
            ``(attn_output, attention_weight, past_key_value)``:
            ``attn_output`` is ``o_proj`` applied to ``attention_weight``, which
            is the per-head attention output concatenated to
            ``(bsz, seqlen, -1)`` (not the softmax probabilities).
            ``past_key_value`` is this call's ``(xk, xv)`` when ``use_cache``
            is set, else None.
        """
        bsz, seqlen, _ = hidden_states.shape
        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

        if seqlen > 1:
            xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])

            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)

            # Re-layout the new keys/values to match cache_v / cache_k.
            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            # NOTE(review): the default cache batch size is 1, so bsz > 1 would
            # not fit these slices — confirm callers.
            self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
            self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store

            # NOTE(review): attention below uses only this call's keys/values,
            # not entries already in the cache, so a multi-token call with
            # start_pos > 0 does not attend to earlier tokens.
            keys = xk
            values = xv
            past_key_value = (xk, xv) if use_cache else None

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            if self.use_alibi:
                scores += self.alibi_bias[..., :seqlen]

            # NOTE(review): no causal mask is applied here when attention_mask
            # is None; the caller is relied on to supply it.
            if attention_mask is not None:
                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)

            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

            past_key_value = (xk, xv) if use_cache else None
            attention_weight = awq_inference_engine.single_query_attention(
                xq, # query
                xk, # key
                xv, # value
                self.cache_k, # key cache
                self.cache_v, # value cache
                None, # length per sample
                self.alibi_slopes, # alibi slopes
                self.start_pos, # timestep
                self.rotary_dim, # rotary embedding dimension
                self.rope_theta, # rotary embedding base
                self.is_neox, # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)

        attn_output = self.o_proj(attention_weight)

        if use_cache:
            self.start_pos += seqlen
        else:
            self.start_pos = 0

        return attn_output, attention_weight, past_key_value