attn.py 11.7 KB
Newer Older
1
import os
Casper Hansen's avatar
Casper Hansen committed
2
import math
Haotian Tang's avatar
Haotian Tang committed
3
4
5
import torch
import torch.nn as nn
import awq_inference_engine
Casper Hansen's avatar
Casper Hansen committed
6
from torch.nn import functional as F
Casper Hansen's avatar
Casper Hansen committed
7

Casper Hansen's avatar
Casper Hansen committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute complex rotary factors ``exp(i * t * theta**(-2k/dim))``.

    Args:
        dim: rotary dimension; ``dim // 2`` frequencies are produced.
        end: number of positions ``t = 0 .. end - 1``.
        theta: frequency base.

    Returns:
        complex64 tensor of shape ``(end, dim // 2)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # Unit-magnitude complex numbers whose phase is the rotation angle.
    return torch.polar(torch.ones_like(angles), angles)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The result keeps
    those two sizes on axes 1 and -1 and has size 1 on every other axis.
    """
    rank = x.ndim
    assert 0 <= 1 < rank
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    kept_axes = {1, rank - 1}
    target = [size if axis in kept_axes else 1 for axis, size in enumerate(x.shape)]
    return freqs_cis.view(*target)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Rotate query/key tensors by precomputed complex rotary factors.

    The last axis of each input is split into two halves. Element ``i`` is
    paired with element ``i + D/2`` (rotate-half layout) to form ``D/2``
    complex numbers. These are multiplied by ``freqs_cis`` and written back in
    the same layout. The math runs in float32; outputs keep the input dtypes.

    Args:
        xq: query tensor; axis 1 is the sequence axis and the last axis is the
            (even) head dimension ``D``, e.g. ``(bsz, seqlen, n_heads, D)``.
        xk: key tensor with the same sequence/last-axis sizes as ``xq``. Other
            axes must broadcast with the factors, so its head count may differ.
        freqs_cis: complex tensor of shape ``(seqlen, D // 2)``.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the same shapes and dtypes as the inputs.
    """

    def _as_complex(x: torch.Tensor) -> torch.Tensor:
        # (..., D) -> (..., 2, D/2) -> (..., D/2, 2); contiguous() gives the
        # unit last-axis stride that view_as_complex requires.
        return torch.view_as_complex(
            x.float().reshape(*x.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )

    def _as_real(x: torch.Tensor) -> torch.Tensor:
        # Inverse of _as_complex. flatten(-2) merges the last two axes for any
        # input rank (a fixed flatten(3) only worked for 4-D inputs).
        return torch.view_as_real(x).transpose(-2, -1).flatten(-2)

    xq_ = _as_complex(xq)
    xk_ = _as_complex(xk)
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = _as_real(xq_ * freqs_cis)
    xk_out = _as_real(xk_ * freqs_cis)
    return xq_out.type_as(xq), xk_out.type_as(xk)

Casper Hansen's avatar
Casper Hansen committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def gen_slopes(n_heads, alibi_bias_max=8):
    """Return ALiBi per-head slopes with shape ``(1, n_heads, 1, 1)``.

    Let ``P`` be ``n_heads`` rounded up to a power of two. The base slopes are
    ``2 ** -(alibi_bias_max * k / P)`` for ``k = 1..P``. When ``n_heads`` is not
    a power of two, the slopes at odd (0-based) positions come first, then
    those at even positions, truncated to ``n_heads`` entries.
    """
    padded = 2 ** math.ceil(math.log2(n_heads))
    exponents = torch.arange(1, padded + 1, dtype=torch.float32).mul(alibi_bias_max / padded)
    base = 1.0 / torch.pow(2, exponents)
    if padded == n_heads:
        return base.view(1, n_heads, 1, 1)
    reordered = torch.concat([base[1::2], base[::2]])
    return reordered[:n_heads].view(1, n_heads, 1, 1)


def build_alibi_bias(
    n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32
):
    """Build ALiBi slopes and the additive attention bias.

    Args:
        n_heads: number of attention heads.
        seq_len: number of key positions covered by the bias.
        full: if true, build a pairwise ``(seq_len, seq_len)`` bias instead of a
            single broadcastable row.
        alibi_bias_max: forwarded to ``gen_slopes``.
        dtype: dtype of both returned tensors.

    Returns:
        ``(slopes, bias)``. ``slopes`` has shape ``(n_heads,)``. ``bias`` has
        shape ``(1, n_heads, 1, seq_len)`` holding ``slope * (j - (seq_len - 1))``
        for key position ``j``. When ``full`` is true it has shape
        ``(1, n_heads, seq_len, seq_len)`` holding ``-slope * |i - j|``.
    """
    positions = torch.arange(1 - seq_len, 1, dtype=torch.int32)
    bias = positions.view(1, 1, 1, seq_len)
    if full:
        # Pairwise (key - query) offsets, folded to a non-positive distance.
        bias = -(bias - positions.view(1, 1, seq_len, 1)).abs()
    head_slopes = gen_slopes(n_heads, alibi_bias_max)
    bias = bias * head_slopes
    return head_slopes.reshape(-1).to(dtype=dtype), bias.to(dtype=dtype)

Haotian Tang's avatar
Haotian Tang committed
62
63
64
65
66
67
68
69

class QuantLlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding applied by ``awq_inference_engine.rotary_embedding_neox``.

    At construction time a fp16 ``cos_sin_cache`` is built. Each row ``t`` is
    ``[cos(t * inv_freq), sin(t * inv_freq)]`` concatenated on the last axis,
    for ``t`` in ``[0, max_position_embeddings)``. ``forward`` passes this cache
    to the fused kernel.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim  # rotary dimension, also passed to the kernel
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the non-persistent fp16 cos/sin table for positions ``[0, seq_len)``.

        ``dtype`` is accepted for interface compatibility but is not used: the
        cache is always stored as fp16.
        """
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # (seq_len, len(inv_freq))
        cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        positions: torch.Tensor,
    ):
        """Apply rotary embedding to ``query`` and ``key`` before attention.

        Both tensors are made contiguous and handed to the fused kernel. The
        kernel's return value is not used; the (possibly copied) tensors are
        returned after the call, so the kernel presumably rotates them in
        place. TODO confirm against the kernel source.
        """
        query = query.contiguous()
        key = key.contiguous()
        awq_inference_engine.rotary_embedding_neox(
            positions,
            query,
            key,
            self.dim,
            self.cos_sin_cache,
        )
        return query, key

Casper Hansen's avatar
Casper Hansen committed
117
class QuantAttentionFused(nn.Module):
    """Attention over a packed QKV projection with a preallocated KV cache.

    The cache follows the FasterTransformer layout: ``cache_k`` packs the head
    dimension in groups of 8 values. There are two execution paths:

    * ``seqlen > 1`` (prefill): rotary embeddings (or ALiBi) and softmax
      attention run in PyTorch. The new keys/values are written into the
      cache, but attention is computed over the current chunk only.
      NOTE(review): a second multi-token chunk with ``start_pos > 0`` does not
      attend to earlier tokens.
    * ``seqlen == 1`` (decode): delegated to
      ``awq_inference_engine.single_query_attention``, which receives the cache,
      the current timestep (``start_pos``) and the rotary/ALiBi settings.

    The cache batch size is read from the ``AWQ_BATCH_SIZE`` environment
    variable (default ``"1"``). Every forward call must use exactly that batch
    size.
    """

    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
                       use_alibi=False, attention_shapes=None):
        """
        Args:
            hidden_size: model width; ``head_dim = hidden_size // n_heads``.
            n_heads: number of query heads.
            n_kv_heads: number of key/value heads. ``0`` selects the packed
                ``(-1, n_heads, head_dim)`` layout with q/k/v at index 0/1/2.
                Otherwise the layout is ``[n_heads q | n_kv_heads k | n_kv_heads v]``.
            qkv_layer: module producing the packed QKV activations.
            o_proj: output projection applied to the attention context.
            dev: device for the caches and positional tables.
            max_seq_len: cache length (and ALiBi bias length).
            use_alibi: use ALiBi biases instead of rotary embeddings.
            attention_shapes: optional dict overriding the shape/slice table
                built below (same keys).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = self.hidden_size // n_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0  # number of tokens already written to the cache
        self.use_alibi = use_alibi
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

        if attention_shapes is not None:
            self.attention_shapes = attention_shapes

        elif self.n_kv_heads == 0:
            self.attention_shapes = {
                # following fastertransformer definition
                "cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
                # 8: pack 8 fp16 in FT, if fp32 then use 4
                "cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,),
                "xqkv_view": (-1, self.n_heads, self.head_dim),
                "xq_slice": lambda xqkv: xqkv[:, :, 0],
                "xk_slice": lambda xqkv: xqkv[:, :, 1],
                "xv_slice": lambda xqkv: xqkv[:, :, 2],
                "xq_view": (self.n_heads, self.head_dim),
                "xk_view": (self.n_heads, self.head_dim),
                "xv_view": (self.n_heads, self.head_dim),
                "xk_reshape": (self.n_heads, self.head_dim // 8, 8),
                "single_xq_view": (self.n_heads, self.head_dim),
                "single_xk_view": (self.n_heads, self.head_dim),
                "single_xv_view": (self.n_heads, self.head_dim)
            }

        else:
            # Packed rows: the first n_heads are queries, then n_kv_heads keys,
            # then n_kv_heads values. Queries keep all n_heads heads; only K/V
            # (and therefore the cache) use n_kv_heads.
            self.attention_shapes = {
                # following fastertransformer definition
                "cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
                # 8: pack 8 fp16 in FT, if fp32 then use 4
                "cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
                "xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
                "xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads],
                "xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
                "xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
                "xq_view": (self.n_heads, self.head_dim),
                "xk_view": (self.n_kv_heads, self.head_dim),
                "xv_view": (self.n_kv_heads, self.head_dim),
                "xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
                "single_xq_view": (self.n_heads, self.head_dim),
                "single_xk_view": (self.n_kv_heads, self.head_dim),
                "single_xv_view": (self.n_kv_heads, self.head_dim)
            }

        self.cache_v = (
            torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
        )

        self.cache_k = (
            torch.zeros(self.attention_shapes["cache_k"]).to(dev).half()
        )

        if use_alibi:
            alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
            self.alibi_slopes = alibi_slopes.float().to(dev)
            self.alibi_bias = alibi_bias.float().to(dev)  # (1, n_heads, 1, max_seq_len)
            self.rotary_dim = 0  # rotary disabled in the decode kernel
            self.is_neox = False
        else:
            # Twice max_seq_len, as in the original implementation.
            self.freqs_cis = precompute_freqs_cis(
                self.head_dim,
                max_seq_len * 2,
            ).to(dev)
            self.rotary_dim = self.head_dim
            self.alibi_slopes = None
            self.is_neox = True

    def forward(
        self,
        hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
    ):
        """Run attention for ``hidden_states`` of shape ``(bsz, seqlen, hidden)``.

        ``past_key_value``, ``position_ids`` and ``output_attentions`` are
        accepted for interface compatibility but are not read. Positions come
        from ``self.start_pos``, which advances by ``seqlen`` when
        ``use_cache`` is true and resets to 0 otherwise.

        Returns:
            ``(attn_output, attention_weight, past_key_value)``.
            ``attention_weight`` is the attention context before ``o_proj``,
            not the softmax probabilities. ``past_key_value`` is
            ``(xk, xv)`` for this step when ``use_cache`` is true, else ``None``.

        Raises:
            RuntimeError: if ``bsz`` differs from the cache batch size.
        """
        bsz, seqlen, _ = hidden_states.shape
        if bsz != self.cache_batch_size:
            raise RuntimeError(
                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
                f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
            )
        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])

        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

        if seqlen > 1:
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])

            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)

            # Write this chunk into the FT-layout caches (K packed by 8 on head_dim).
            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )
            self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
            self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store

            past_key_value = (xk, xv) if use_cache else None

            keys = xk
            values = xv
            # Grouped-query attention: expand K/V heads so each query head has
            # its group's K/V. A single K/V head already broadcasts in matmul.
            n_q_heads, n_kv_heads = xq.shape[2], keys.shape[2]
            if 1 < n_kv_heads < n_q_heads:
                n_rep = n_q_heads // n_kv_heads
                keys = torch.repeat_interleave(keys, repeats=n_rep, dim=2)
                values = torch.repeat_interleave(values, repeats=n_rep, dim=2)

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            # Add biases in fp32: ALiBi offsets grow with sequence length and
            # lose most of their precision if accumulated in fp16.
            scores = scores.float()
            if self.use_alibi:
                # Take the last seqlen entries so the newest key gets zero
                # bias. This is equivalent under softmax to any other window.
                scores = scores + self.alibi_bias[..., -seqlen:]

            if attention_mask is not None:
                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)

            scores = F.softmax(scores, dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

            past_key_value = (xk, xv) if use_cache else None
            # NOTE(review): with n_kv_heads < n_heads this relies on the kernel
            # deriving the K/V head count from xk/xv shapes. TODO confirm.
            attention_weight = awq_inference_engine.single_query_attention(
                xq, # query
                xk, # key
                xv, # value
                self.cache_k, # key cache
                self.cache_v, # value cache
                None, # length per sample
                self.alibi_slopes, # alibi slopes
                self.start_pos, # timestep
                self.rotary_dim, # rotary embedding dimension
                10000, # rotary embedding base (matches precompute_freqs_cis default)
                self.is_neox, # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)

        attn_output = self.o_proj(attention_weight)

        if use_cache:
            self.start_pos += seqlen
        else:
            self.start_pos = 0

        return attn_output, attention_weight, past_key_value