attn.py
import math
import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F

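# Rotary position embedding (RoPE) helpers in the complex-number formulation:
# precompute_freqs_cis builds a table of unit complex numbers e^(i * t * freq),
# reshape_for_broadcast aligns that table with a (batch, seq, heads, head_dim/2)
# tensor, and apply_rotary_emb rotates query/key pairs (element j is paired with
# element j + head_dim // 2). These are used on the multi-token prefill path of
# QuantLlamaAttentionFused below.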
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    xq_ = torch.view_as_complex(
        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    xk_ = torch.view_as_complex(
        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

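# Rotary embedding module for the kernel path: precomputes a half-precision
# [cos | sin] cache and delegates the actual rotation of query and key to the
# awq_inference_engine.rotary_embedding_neox CUDA kernel.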
class QuantLlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
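        # Standard RoPE inverse frequencies: 1 / base^(2i / dim) for i = 0 .. dim/2 - 1.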
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from the paper, but it uses a different permutation to obtain the same result.
        # Note: `emb` is not used below; the cache stores cos and sin of `freqs` directly.
        emb = torch.cat((freqs, freqs), dim=-1)

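        # Concatenate cos and sin along the last dimension, so that
        # cache[t] = [cos(t * inv_freq) | sin(t * inv_freq)]; the fp16 copy below is
        # the buffer handed to the rotary kernel.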
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        positions: torch.Tensor,
    ):
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
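        # The kernel rotates query and key in place (they are returned unchanged
        # below), so make sure both are contiguous first.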
        query = query.contiguous()
        key = key.contiguous()
        awq_inference_engine.rotary_embedding_neox(
            positions,
            query,
            key,
            self.dim,
            self.cos_sin_cache
        )
        return query, key

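# Fused LLaMA attention block: a single quantized QKV projection, fp16 KV caches
# kept in the FasterTransformer layout, and two execution paths. Multi-token
# prefill (seqlen > 1) runs standard PyTorch attention; single-token decoding
# calls the fused single_query_attention CUDA kernel.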
class QuantLlamaAttentionFused(nn.Module):
    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_position_embeddings):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_local_heads = num_heads
        self.head_dim = self.hidden_size // num_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0

        # Value cache, following the FasterTransformer layout:
        # (batch, n_heads, max_seq_len, head_dim)
        self.cache_v = (
            torch.zeros(
                (1, self.n_local_heads, max_position_embeddings, self.head_dim)
            ).to(dev).half()
        )

        # Key cache: head_dim is split into groups of 8 packed fp16 values
        # (FasterTransformer layout; a group size of 4 would be used for fp32).
        self.cache_k = (
            torch.zeros(
                (1, self.n_local_heads, self.head_dim // 8, max_position_embeddings, 8)
            ).to(dev).half()
        )

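        # Complex rotary table used on the multi-token prefill path; the
        # single-token decode path relies on the kernel's own rotary handling.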
        self.freqs_cis = precompute_freqs_cis(
            hidden_size // num_heads,
            max_position_embeddings * 2,
        ).to(dev)
        
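    # HF-style signature: returns (attn_output, attn_weights, past_key_value).
    # Attention weights are never computed here (always None), and
    # position_ids / output_attentions are accepted but ignored.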
    def forward(
        self,
        hidden_states,
        past_key_value=None,
        attention_mask=None,
        position_ids=None,
        output_attentions=False,
        use_cache=False,
    ):
        bsz, seqlen, _ = hidden_states.shape
        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view(bsz, seqlen, -1, self.n_local_heads, self.head_dim)
        xq = xqkv[:, :, 0]
        xk = xqkv[:, :, 1]
        xv = xqkv[:, :, 2]

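        # Prefill path (seqlen > 1): apply RoPE in Python, store this chunk in the
        # KV caches, and run standard fp16 attention over the new tokens.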
        if seqlen > 1:
            xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

            xq, xk = apply_rotary_emb(
                xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen]
            )

            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)

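            # Rearrange to the cache layouts: values as (bsz, heads, seqlen, head_dim),
            # keys packed into groups of 8 fp16 values along head_dim to match cache_k.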
            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape(bsz, seqlen, self.n_local_heads, self.head_dim // 8, 8)
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

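            # Write this chunk into the caches starting at start_pos; the attention
            # below is computed only over the keys/values of the current chunk (the
            # caches are not read on this path).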
            self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
            self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store

            keys = xk
            values = xv
            past_key_value = (xk, xv) if use_cache else None

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
            if attention_mask is not None:
                scores = scores + attention_mask  # (bsz, n_local_heads, seqlen, seqlen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
            xq = xq[:, 0, :, :]
            xk = xk[:, 0, :, :]
            xv = xv[:, 0, :, :]
            past_key_value = (xk, xv) if use_cache else None
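            # Fused decode kernel. Judging from this call site, it applies rotary
            # embedding to the new query/key, appends them to cache_k/cache_v, and
            # attends over all cached positions; the two None arguments are optional
            # kernel inputs (per-sequence lengths, ALiBi slopes) left unused, followed
            # by the current timestep, rotary dimension, rope base (10000) and the
            # neox-style flag. See awq_inference_engine for the exact signature.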
            output = awq_inference_engine.single_query_attention(
                xq,
                xk,
                xv,
                self.cache_k,
                self.cache_v,
                None,
                None,
                self.start_pos,
                self.head_dim,
                10000,
                True,
            )
            output = output.reshape(bsz, 1, -1)
        
        attn_output = self.o_proj(output)
        
        if use_cache:
            # Advance the cache write position by the number of tokens just processed.
            self.start_pos += seqlen
        else:
            # Not caching: reset so the next call starts writing at position 0.
            self.start_pos = 0

        return attn_output, None, past_key_value