import torch

from awq.modules.linear import (
    WQLinear_GEMM,
    WQLinear_GEMV,
    WQLinear_Marlin,
    WQLinear_Exllama,
    WQLinear_ExllamaV2,
    WQLinear_GEMVFast,
)


def prepare_correct_devices(next_layer, hidden_states, mask):
    hidden_states = hidden_states.to(next_layer.device)

    if mask is not None:
        mask = mask.to(next_layer.device)

    return hidden_states, mask
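
# Hedged usage sketch (assumed names, not from this file): typically called
# between pipeline-parallel blocks so activations follow the weights across GPUs.
#
#     hidden_states, mask = prepare_correct_devices(
#         self.blocks[i + 1], hidden_states, mask
#     )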


def prepare_cache(blocks, seqlen: int) -> None:
    for block in blocks:
        start_pos = block.attn.start_pos
        will_cache_be_exceeded = start_pos + seqlen > block.attn.max_seq_len

        # Reset the cache so no stale state is retained when processing a new context
        if seqlen > 1 and (will_cache_be_exceeded or start_pos > 0):
            block.attn.start_pos = block.attn.cache.roll_kv_n_steps(
                start_pos, n=start_pos
            )

        # During decoding, roll old tokens out of the cache 100 at a time so it
        # never overflows and the rolling cost is amortized across steps
        elif seqlen == 1 and will_cache_be_exceeded:
            block.attn.start_pos = block.attn.cache.roll_kv_n_steps(start_pos, n=100)
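
# Hedged usage sketch (assumed forward structure, not from this file):
# prepare_cache is meant to run once per forward pass, before the fused blocks,
# so the KV cache is reset or rolled before any block writes into it.
#
#     def forward(self, hidden_states):
#         prepare_cache(self.blocks, seqlen=hidden_states.shape[1])
#         ...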


def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int):
    # NOTE: from transformers 4.35.0, input_ids includes the full context
    # during decoding, so only the not-yet-processed tokens count as new
    num_input_tokens = input_ids.shape[-1]
    num_new_tokens = num_input_tokens

    if num_input_tokens != 1:
        num_new_tokens = num_input_tokens - last_forward_num_tokens

        # after the context has been processed, slice down to the latest token
        if num_new_tokens == 1:
            input_ids = input_ids[:, -1:]

    return input_ids, last_forward_num_tokens + num_new_tokens
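
# Hedged, self-contained sketch (assumed example values, not from this file):
# shows how prepare_input_ids behaves across a context pass and a decode step.
def _example_prepare_input_ids():
    ids = torch.tensor([[1, 2, 3, 4]])  # full prompt on the context pass
    ids, seen = prepare_input_ids(ids, last_forward_num_tokens=0)
    assert ids.shape == (1, 4) and seen == 4  # all prompt tokens kept

    ids = torch.tensor([[1, 2, 3, 4, 5]])  # transformers >= 4.35 resends history
    ids, seen = prepare_input_ids(ids, last_forward_num_tokens=seen)
    assert ids.shape == (1, 1) and seen == 5  # sliced to the newest token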


def prepare_attention_mask(seqlen, start_pos, device, type_as: torch.Tensor):
    mask = None
    if seqlen > 1:
        # Causal mask: each position may attend to itself and earlier positions;
        # entries above the (start_pos-shifted) diagonal are -inf
        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=device)
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(type_as)

    return mask
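
# Hedged, self-contained sketch (assumed example values): the result is the
# standard additive causal mask, and single-token decoding needs no mask at all.
def _example_prepare_attention_mask():
    ref = torch.empty(0, dtype=torch.float16)
    mask = prepare_attention_mask(seqlen=3, start_pos=0, device="cpu", type_as=ref)
    # mask[0, 0] is:
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    assert mask.shape == (1, 1, 3, 3) and mask.dtype == torch.float16

    assert prepare_attention_mask(1, 10, "cpu", ref) is None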


def fuse_qkv(module, q_proj, k_proj, v_proj):
    bias = (
        torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0)
        if q_proj.bias is not None
        else None
    )

    if isinstance(q_proj, WQLinear_GEMV):
        q_linear = WQLinear_GEMV
    elif isinstance(q_proj, WQLinear_GEMM):
        q_linear = WQLinear_GEMM
    elif isinstance(q_proj, WQLinear_Exllama):
        q_linear = WQLinear_Exllama
    elif isinstance(q_proj, WQLinear_ExllamaV2):
        q_linear = WQLinear_ExllamaV2
    elif isinstance(q_proj, WQLinear_Marlin):
        q_linear = WQLinear_Marlin
    elif isinstance(q_proj, WQLinear_GEMVFast):
        q_linear = WQLinear_GEMVFast
    else:
        raise ValueError(f"Unsupported linear type for QKV fusion: {type(q_proj)}")

    qkv_layer = q_linear(
        q_proj.w_bit,
        q_proj.group_size,
        q_proj.in_features,
        q_proj.out_features + k_proj.out_features + v_proj.out_features,
        q_proj.bias is not None,
        next(iter(module.state_dict().values())).device,
    )

    if isinstance(q_proj, WQLinear_GEMV):
        # GEMV packs along dim=0 and carries split_k_iters over to the fused layer
        qkv_layer.qweight = torch.cat(
            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0
        )
        qkv_layer.qzeros = torch.cat(
            [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0
        )
        qkv_layer.scales = torch.cat(
            [q_proj.scales, k_proj.scales, v_proj.scales], dim=0
        )
        qkv_layer.split_k_iters = q_proj.split_k_iters
    elif isinstance(q_proj, WQLinear_GEMM):
        qkv_layer.qweight = torch.cat(
            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1
        )
        qkv_layer.qzeros = torch.cat(
            [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
        )
        qkv_layer.scales = torch.cat(
            [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
        )
    elif isinstance(q_proj, WQLinear_Exllama):
        qkv_layer.qweight = torch.cat(
            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1
        )
        qkv_layer.qzeros = torch.cat(
            [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
        )
        qkv_layer.scales = torch.cat(
            [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
        )
    elif isinstance(q_proj, WQLinear_ExllamaV2):
        qkv_layer.qweight = torch.cat(
            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1
        )
        qkv_layer.qzeros = torch.cat(
            [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
        )
        qkv_layer.scales = torch.cat(
            [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
        )
    elif isinstance(q_proj, WQLinear_Marlin):
        qkv_layer.qweight = torch.cat(
            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1
        )
        qkv_layer.scales = torch.cat(
            [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
        )
        # Marlin has no qzeros; its workspace is created in post_init
    elif isinstance(q_proj, WQLinear_GEMVFast):
        qkv_layer.qweight = torch.cat(
            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0
        )
        qkv_layer.qzeros = torch.cat(
            [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
        ).contiguous()
        qkv_layer.scales = torch.cat(
            [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
        ).contiguous()
        qkv_layer.split_k_iters = q_proj.split_k_iters

    qkv_layer.bias = bias

    # Free the original tensors; the fused layer now owns the only copies
    for layer in [q_proj, k_proj, v_proj]:
        del (layer.qweight, layer.qzeros, layer.scales)

    return qkv_layer
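
# Hedged usage sketch (assumed attribute paths, not from this file): fuse the
# separate quantized projections of one attention module into a single layer.
#
#     attn = model.model.layers[0].self_attn
#     qkv = fuse_qkv(attn, attn.q_proj, attn.k_proj, attn.v_proj)
#     # qkv(x) now yields [q | k | v] concatenated on the last dimension;
#     # callers split it back using each projection's out_features.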


def fuse_linears(linears, device, dim=1, operation=torch.cat):
    total_out_features = sum([layer.out_features for layer in linears])
    fused = WQLinear_GEMM(
        linears[0].w_bit,
        linears[0].group_size,
        linears[0].in_features,
        total_out_features,
        bias=None,
        dev=device,
    )
    fused.qweight = operation([layer.qweight for layer in linears], dim=dim)
    fused.qzeros = operation([layer.qzeros for layer in linears], dim=dim)
    fused.scales = operation([layer.scales for layer in linears], dim=dim)

    for layer in linears:
        del (layer.qweight, layer.qzeros, layer.scales, layer)

    return fused
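
# Hedged usage sketch (assumed attribute paths): fuse an MLP's gate and up
# projections (both WQLinear_GEMM) so they run as a single GEMM.
#
#     mlp = model.model.layers[0].mlp
#     gate_up = fuse_linears([mlp.gate_proj, mlp.up_proj], device="cuda:0")
#     # gate_up(x) returns both projections concatenated along the last dim;
#     # split with out.split(
#     #     [mlp.gate_proj.out_features, mlp.up_proj.out_features], dim=-1)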


def get_attention_shapes(
    attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim
):
    if attention_shapes is not None:
        # Caller supplied explicit shapes; use them as-is.
        pass

    elif n_kv_heads == 0:
        # Multi-head attention: queries, keys, and values all use n_heads.
        attention_shapes = {
            # following the FasterTransformer cache layout
            "cache_v": (
                cache_batch_size,
                n_heads,
                max_seq_len,
                head_dim,
            ),
            # 8: FasterTransformer packs 8 fp16 values per slot (4 for fp32)
            "cache_k": (
                cache_batch_size,
                n_heads,
                head_dim // 8,
                max_seq_len,
                8,
            ),
            "xqkv_view": (-1, n_heads, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, 0],
            "xk_slice": lambda xqkv: xqkv[:, :, 1],
            "xv_slice": lambda xqkv: xqkv[:, :, 2],
            "xq_view": (n_heads, head_dim),
            "xk_view": (n_heads, head_dim),
            "xv_view": (n_heads, head_dim),
            "xk_reshape": (n_heads, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (n_heads, head_dim),
            "single_xv_view": (n_heads, head_dim),
        }

    else:
        # Grouped-query attention: keys and values use n_kv_heads < n_heads.
        attention_shapes = {
            # following the FasterTransformer cache layout
            "cache_v": (
                cache_batch_size,
                n_kv_heads,
                max_seq_len,
                head_dim,
            ),
            # 8: FasterTransformer packs 8 fp16 values per slot (4 for fp32)
            "cache_k": (
                cache_batch_size,
                n_kv_heads,
                head_dim // 8,
                max_seq_len,
                8,
            ),
            "xqkv_view": (n_heads + n_kv_heads * 2, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, 0:n_heads],
            "xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)],
            "xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads:],
            "xq_view": (n_heads, head_dim),
            "xk_view": (n_kv_heads, head_dim),
            "xv_view": (n_kv_heads, head_dim),
            "xk_reshape": (n_kv_heads, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (n_kv_heads, head_dim),
            "single_xv_view": (n_kv_heads, head_dim),
        }

    return attention_shapes
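
# Hedged, self-contained sketch (assumed example values): the GQA branch sizes
# the KV cache by n_kv_heads while queries keep the full head count.
def _example_get_attention_shapes():
    shapes = get_attention_shapes(
        attention_shapes=None,
        max_seq_len=2048,
        cache_batch_size=1,
        n_heads=32,
        n_kv_heads=8,
        head_dim=128,
    )
    assert shapes["cache_v"] == (1, 8, 2048, 128)
    assert shapes["cache_k"] == (1, 8, 128 // 8, 2048, 8)  # 8 fp16 packed per slot
    assert shapes["xqkv_view"] == (32 + 2 * 8, 128)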