attn.py
import math
import torch
import torch.nn as nn
import awq_inference_engine  # compiled AWQ CUDA extension (rotary + fused single-query attention kernels)
from torch.nn import functional as F


class QuantLlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding that precomputes a half-precision [cos | sin]
    cache and applies it through the fused NeoX-style kernel in awq_inference_engine."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Standard RoPE inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation.
        # `emb` is only needed by the commented-out HF-style cos/sin buffers below; the fused kernel
        # consumes the [cos | sin] cache built directly from `freqs`.
        emb = torch.cat((freqs, freqs), dim=-1)

        cos = freqs.cos()
        sin = freqs.sin()
        # Cache layout consumed by the kernel: [max_seq_len, rotary_dim], with cos in
        # the first half of the last dimension and sin in the second half.
        cache = torch.cat((cos, sin), dim=-1)

        # self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        # self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
        # Stored in fp16 regardless of the requested `dtype`, matching the fp16 kernels.
        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        positions: torch.Tensor,
    ):
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
        # The kernel rotates query/key in place, so make them contiguous first.
        query = query.contiguous()
        key = key.contiguous()
        awq_inference_engine.rotary_embedding_neox(
            positions,
            query,
            key,
            self.dim,
            self.cos_sin_cache
        )
        return query, key
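

# A minimal, non-fused PyTorch sketch of what the NeoX-style rotary kernel above is
# understood to compute, included for readability/debugging only. It assumes query/key
# of shape [bsz, seqlen, n_heads, head_dim] and the [cos | sin] cache layout built in
# _set_cos_sin_cache. This reference helper is not part of the original module and is
# never called by the fused path.
def _reference_rotary_neox(query, key, positions, cos_sin_cache, rotary_dim):
    half = rotary_dim // 2
    # Gather cos/sin rows for each position and broadcast over the head dimension.
    cos = cos_sin_cache[positions, :half].unsqueeze(-2).to(query.dtype)
    sin = cos_sin_cache[positions, half:].unsqueeze(-2).to(query.dtype)

    def rotate(x):
        # NeoX-style rotation: split the head dim into two halves and rotate them jointly.
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

    return rotate(query), rotate(key)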


class QuantLlamaAttentionFused(nn.Module):
    """Fused LLaMA attention over AWQ-quantized projections.

    Maintains FasterTransformer-style key/value caches and tracks the write
    offset in `start_pos`: prompt prefill (seqlen > 1) runs a batched PyTorch
    attention path, while single-token decode steps call the fused
    single-query attention kernel.
    """

    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_position_embeddings):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_local_heads = num_heads
        self.head_dim = self.hidden_size // num_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0

        # KV caches follow the FasterTransformer layout.
        # V cache: [1, n_heads, max_position_embeddings, head_dim]
        self.cache_v = (
            torch.zeros(
                ( 1, self.n_local_heads, max_position_embeddings, self.head_dim, )
            ).to(dev).half()
        )
        
        # K cache: [1, n_heads, head_dim // 8, max_position_embeddings, 8]
        # FasterTransformer packs 8 fp16 values along the last dim (it would be 4 for fp32).
        self.cache_k = (
            torch.zeros(
                ( 1, self.n_local_heads, self.head_dim // 8, max_position_embeddings, 8, )
            ).to(dev).half()
        )

        self.rotary_emb = QuantLlamaRotaryEmbedding(
            dim=hidden_size // num_heads,
            max_position_embeddings=max_position_embeddings,
            device=dev
        )
        
    def forward(
        self,
        hidden_states,
        past_key_value=None,
        attention_mask=None,
        position_ids=None,
        output_attentions=False,
        use_cache=False,
    ):
        bsz, seqlen, _ = hidden_states.shape
        # Fused QKV projection, split into query/key/value: [bsz, seqlen, 3, n_heads, head_dim]
        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view(bsz, seqlen, -1, self.n_local_heads, self.head_dim)
        xq = xqkv[:, :, 0]
        xk = xqkv[:, :, 1]
        xv = xqkv[:, :, 2]

        if seqlen > 1:
            # Prefill path: attend over the whole prompt with batched PyTorch ops.
            xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

            xq, xk = self.rotary_emb(xq, xk, position_ids)

            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)

            # Repack the current keys/values into the FasterTransformer cache layouts.
            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape(bsz, seqlen, self.n_local_heads, self.head_dim // 8, 8)
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            # Write the prompt's keys/values into the caches at start_pos so that
            # subsequent decode steps can attend to them.
            self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
            self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store

            # Prefill attention reads only the current tokens; the caches are not consulted here.
            keys = xk
            values = xv
            past_key_value = (xk, xv) if use_cache else None

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
            if attention_mask is not None:
                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
            # Decode path: a single new token is handled by the fused CUDA kernel,
            # which applies rotary embedding, appends to the caches at start_pos,
            # and attends over everything cached so far.
            xq = xq[:, 0, :, :]
            xk = xk[:, 0, :, :]
            xv = xv[:, 0, :, :]
            past_key_value = (xk, xv) if use_cache else None
            output = awq_inference_engine.single_query_attention(
                xq,
                xk,
                xv,
                self.cache_k,
                self.cache_v,
                None,
                None,
                self.start_pos,  # current timestep / cache offset
                self.head_dim,   # rotary embedding dimension
                10000,           # rotary base
                True,            # NeoX-style rotary
            )
            output = output.reshape(bsz, 1, -1)
        
        attn_output = self.o_proj(output)
        
        # Advance the cache write position while generating; reset it when the cache is unused.
        if use_cache:
            self.start_pos += seqlen
        else:
            self.start_pos = 0

        return attn_output, None, past_key_value
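

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): how this fused attention block
# might be exercised in isolation. It assumes a CUDA device and a built
# awq_inference_engine extension; the plain fp16 nn.Linear layers below are
# hypothetical stand-ins for the quantized QKV/output projections that the
# real model would pass in, and in the real model a causal attention_mask
# would also be supplied for the prefill step.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    hidden_size, num_heads, max_pos = 4096, 32, 2048
    dev = "cuda"

    # Stand-in projections: fused QKV (3 * hidden_size outputs) and output projection.
    qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False).to(dev).half()
    o_proj = nn.Linear(hidden_size, hidden_size, bias=False).to(dev).half()

    attn = QuantLlamaAttentionFused(
        hidden_size, num_heads, qkv, o_proj, dev, max_position_embeddings=max_pos
    )

    # Prefill: seqlen > 1 takes the batched PyTorch attention path and fills the caches.
    x = torch.randn(1, 8, hidden_size, device=dev, dtype=torch.float16)
    pos = torch.arange(8, device=dev).unsqueeze(0)
    out, _, _ = attn(x, position_ids=pos, use_cache=True)

    # Decode: seqlen == 1 takes the fused single_query_attention kernel.
    x_step = torch.randn(1, 1, hidden_size, device=dev, dtype=torch.float16)
    pos_step = torch.tensor([[8]], device=dev)
    out_step, _, _ = attn(x_step, position_ids=pos_step, use_cache=True)
    print(out.shape, out_step.shape)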