attn.py
import math
import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex (polar-form) rotary embedding table of shape (end, dim // 2)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
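
# The table above holds unit-magnitude rotations: entry (m, j) is exp(i * m * theta_j)
# with theta_j = theta ** (-2j / dim), i.e. one complex rotation per position m and per
# 2-D feature pair j. apply_rotary_emb below multiplies the complex-viewed query/key
# pairs by these entries to rotate them by a position-dependent angle.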

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape freqs_cis so it broadcasts against x over the sequence and last (head-dim) axes."""
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Apply rotary position embeddings to xq and xk by rotating feature pairs in the complex plane."""
    xq_ = torch.view_as_complex(
        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    xk_ = torch.view_as_complex(
        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class QuantLlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding whose application is fused into an awq_inference_engine kernel."""
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from the paper: a different permutation is used, but it yields the
        # same result. Note that `emb` is not consumed below; the fused kernel uses the
        # cos/sin of `freqs` directly.
        emb = torch.cat((freqs, freqs), dim=-1)

        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        positions: torch.Tensor,
    ):
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
        query = query.contiguous()
        key = key.contiguous()
        awq_inference_engine.rotary_embedding_neox(
            positions,
            query,
            key,
            self.dim,
            self.cos_sin_cache
        )
        return query, key
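
# Illustrative call pattern (a sketch only; the names, shapes, and device below are
# assumptions for the example, and the rotation itself is applied in place by the
# awq_inference_engine CUDA kernel):
#   rope = QuantLlamaRotaryEmbedding(head_dim, max_position_embeddings=2048, device="cuda")
#   query, key = rope(query, key, positions)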

class QuantLlamaAttentionFused(nn.Module):
    """LLaMA attention block backed by awq_inference_engine kernels and FasterTransformer-style KV caches."""
    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_position_embeddings):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_local_heads = num_heads
        self.head_dim = self.hidden_size // num_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0
        self.use_sdpa_torch = False

        # Value cache laid out following the FasterTransformer definition
        self.cache_v = (
            torch.zeros(
                (1, self.n_local_heads, max_position_embeddings, self.head_dim)
            ).to(dev).half()
        )
        
        # FasterTransformer packs 8 fp16 values per group along the head dim (use 4 for fp32)
        self.cache_k = (
            torch.zeros(
                (1, self.n_local_heads, self.head_dim // 8, max_position_embeddings, 8)
            ).to(dev).half()
        )
        self.freqs_cis = precompute_freqs_cis(
            hidden_size // num_heads,
            max_position_embeddings * 2,
        ).to(dev)
    
    def _multi_query_attention_torch(
        self, query, key, value, batch_size, seqlen, use_cache, past_key_value
    ):
        # Faster prompt processing via torch's fused scaled-dot-product attention.
        query = query.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)

        if use_cache:
            key = key.contiguous()
            value = value.contiguous()
            query = query.contiguous()

        output = F.scaled_dot_product_attention(
            query, key, value, is_causal=past_key_value is None
        )

        del query, key, value

        output = output.transpose(1, 2).reshape(batch_size, seqlen, self.hidden_size)

        return output
    
    def forward(
        self,
        hidden_states,
        past_key_value=None,
        attention_mask=None,
        position_ids=None,
        output_attentions=False,
        use_cache=False,
    ):
        bsz, seqlen, _ = hidden_states.shape
        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view(bsz, seqlen, -1, self.n_local_heads, self.head_dim)
        xq = xqkv[:, :, 0]
        xk = xqkv[:, :, 1]
        xv = xqkv[:, :, 2]

        # Prefill path (seqlen > 1): apply rotary embeddings in PyTorch and write the
        # new keys/values into the FasterTransformer-layout caches.
        if seqlen > 1:
            xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

            xq, xk = apply_rotary_emb(
                xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen]
            )

            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)

            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape(bsz, seqlen, self.n_local_heads, self.head_dim // 8, 8)
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
            self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store

            keys = xk
            values = xv
            past_key_value = (xk, xv) if use_cache else None

            if self.use_sdpa_torch:
                output = self._multi_query_attention_torch(xq, xk, xv, bsz, seqlen, True, past_key_value)
            else:
                xq = xq.transpose(1, 2)
                keys = keys.transpose(1, 2)
                values = values.transpose(1, 2)
                scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
                if attention_mask is not None:
                    scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
                scores = F.softmax(scores.float(), dim=-1).type_as(xq)
                output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
                output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
            # Decode path (seqlen == 1): delegate to the fused single-query attention
            # kernel, which operates on the FasterTransformer-layout caches at offset start_pos.
            xq = xq[:, 0, :, :]
            xk = xk[:, 0, :, :]
            xv = xv[:, 0, :, :]
            past_key_value = (xk, xv) if use_cache else None
            output = awq_inference_engine.single_query_attention(
                xq,
                xk,
                xv,
                self.cache_k,
                self.cache_v,
                None,
                None,
                self.start_pos,
                self.head_dim,
                10000,
                True,
            )
            output = output.reshape(bsz, 1, -1)
        
        attn_output = self.o_proj(output)
        
        # start_pos tracks the KV-cache write offset: it advances while the cache is
        # in use and resets to zero otherwise.
        if use_cache:
            self.start_pos += seqlen
        else:
            self.start_pos = 0

        return attn_output, None, past_key_value
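

if __name__ == "__main__":
    # Minimal smoke test for the pure-PyTorch rotary helpers above (a sketch only;
    # the sizes are arbitrary example values, and importing this module still
    # requires the awq_inference_engine extension to be installed).
    bsz, seqlen, n_heads, head_dim = 1, 16, 32, 128
    freqs_cis = precompute_freqs_cis(head_dim, seqlen)
    xq = torch.randn(bsz, seqlen, n_heads, head_dim)
    xk = torch.randn(bsz, seqlen, n_heads, head_dim)
    xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
    assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
    print("rotary smoke test passed:", tuple(xq_rot.shape), tuple(xk_rot.shape))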