pre_quant.py 5.85 KB
Newer Older
Ji Lin's avatar
Ji Lin committed
1
2
3
4
5
6
7
import torch
import torch.nn as nn
import tqdm
import gc
import functools
from collections import defaultdict

8
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
Ji Lin's avatar
Ji Lin committed
9
10
11
12
13
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from .auto_scale import auto_scale_block, apply_scale
from .auto_clip import auto_clip_block, apply_clip
Casper's avatar
Casper committed
14
from ..models import MptAWQForCausalLM
Ji Lin's avatar
Ji Lin committed
15
16
17
18
19
20
21
22
23
24
25
26
27

__all__ = ["run_awq"]


def get_named_linears(module):
    """Collect every ``nn.Linear`` found under *module*.

    Args:
        module: the ``nn.Module`` to walk (the module itself is included,
            under the empty-string name, by ``named_modules``).

    Returns:
        dict mapping each linear's dotted name (as produced by
        ``module.named_modules()``) to the linear module, in traversal order.
    """
    linears = {}
    for qualified_name, submodule in module.named_modules():
        if not isinstance(submodule, nn.Linear):
            continue
        linears[qualified_name] = submodule
    return linears


def get_blocks(model):
    """Return the container of transformer blocks for *model*.

    Llama, OPT and Bloom are matched with ``isinstance``; MPT and Falcon are
    matched by a case-insensitive substring of ``str(model.__class__)``.
    The checks run in that order and the first match wins.

    Args:
        model: a causal-LM model instance.

    Returns:
        The model's sequence of decoder blocks (e.g. ``model.model.layers``
        for Llama, ``model.transformer.h`` for Bloom and Falcon).

    Raises:
        NotImplementedError: if the architecture is not recognised.
    """
    if isinstance(model, LlamaForCausalLM):
        return model.model.layers
    if isinstance(model, OPTForCausalLM):
        return model.model.decoder.layers
    if isinstance(model, BloomForCausalLM):
        return model.transformer.h

    class_repr = str(model.__class__).lower()
    if "mpt" in class_repr:
        return MptAWQForCausalLM.get_model_layers(model)
    if "falcon" in class_repr:
        # Falcon stores its blocks under the same attribute path as Bloom.
        return model.transformer.h

    raise NotImplementedError(type(model))
    
38
39
40
41
42
43
44
45
46
47
def move_embed(model, device):
    """Move the embedding-related modules of *model* to *device*.

    Modules moved per architecture:
      * Llama:  ``model.embed_tokens``
      * OPT:    ``model.decoder.embed_tokens`` and ``model.decoder.embed_positions``
      * Bloom:  ``transformer.word_embeddings`` and
                ``transformer.word_embeddings_layernorm``
      * MPT:    delegated to ``MptAWQForCausalLM.move_embed``
      * Falcon: ``transformer.word_embeddings``

    Llama, OPT and Bloom are matched with ``isinstance``; MPT and Falcon by a
    case-insensitive substring of ``str(model.__class__)``.

    Args:
        model: a causal-LM model instance (modified in place).
        device: any target accepted by ``nn.Module.to`` (e.g. ``"cuda"``).

    Raises:
        NotImplementedError: if the architecture is not recognised.
    """
    class_repr = str(model.__class__).lower()
    if isinstance(model, LlamaForCausalLM):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    elif isinstance(model, OPTForCausalLM):
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
    elif isinstance(model, BloomForCausalLM):
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
        model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
    elif "mpt" in class_repr:
        MptAWQForCausalLM.move_embed(model, device)
    elif "falcon" in class_repr:
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
    else:
        raise NotImplementedError(type(model))
Ji Lin's avatar
Ji Lin committed
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75

class _Layer0InputCaptured(ValueError):
    """Signal raised by the layer-0 catcher to abort the forward pass early.

    It subclasses ``ValueError`` so the exception family matches what the
    catcher raised historically. ``run_awq`` catches only this subclass, so a
    genuine ``ValueError`` from the model is no longer silently swallowed.
    """


def _capture_layer0_inputs(model, layers, samples):
    """Run *samples* through *model* just far enough to record layer-0 inputs.

    ``layers[0]`` is temporarily wrapped by a module whose forward stores the
    positional input and keyword arguments it receives, then raises to stop
    the rest of the forward pass. The original layer is always put back, even
    if the forward pass fails for an unrelated reason.

    Args:
        model: the full model; its forward is called on *samples*.
        layers: the block container returned by ``get_blocks(model)``.
        samples: token-id tensor. It is moved to the device of the model's
            first parameter before the call.

    Returns:
        ``(inps, layer_kwargs)``: the first positional input passed to layer 0
        and a dict of the keyword arguments it was called with.
    """
    captured = []
    layer_kwargs = {}

    # with_kwargs forward hooks are only available in PyTorch 2.0, so wrap the
    # layer instead and raise to short-circuit the remaining computation.
    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            captured.append(inp)
            layer_kwargs.update(kwargs)
            raise _Layer0InputCaptured

    layers[0] = Catcher(layers[0])
    try:
        model(samples.to(next(model.parameters()).device))
    except _Layer0InputCaptured:  # expected early exit
        pass
    finally:
        layers[0] = layers[0].module  # always restore the original layer
    return captured[0], layer_kwargs


def _cache_input_hook(m, x, y, name, feat_dict):
    """Forward hook: append a detached CPU copy of the first input of *m*.

    Args:
        m: the hooked module (unused).
        x: tuple of positional inputs passed to the module's forward.
        y: the module output (unused).
        name: key under which the input is recorded.
        feat_dict: mapping from name to a list of recorded inputs.
    """
    feat_dict[name].append(x[0].detach().cpu())


@torch.no_grad()
def run_awq(
    model, enc,
    w_bit, q_config,
    n_samples=512, seqlen=512,
    auto_scale=True, mse_range=True,
    # some configs for ablation study
    calib_data="pileval",
):
    """Search AWQ scaling and clipping parameters block by block.

    A calibration batch is pushed through the embeddings to capture the input
    of the first block. Each block is then processed in order:
      1. It is moved to CUDA.
      2. It is run once, recording the input of every ``nn.Linear`` inside it.
         The block output becomes the next block's input.
      3. Scales are searched and applied (``auto_scale``), then clipping
         ranges are searched and applied (``mse_range``).
      4. It is moved back to CPU.

    Requires a CUDA device.

    Args:
        model: causal-LM model supported by ``get_blocks``/``move_embed``.
        enc: tokenizer, passed to ``get_calib_dataset`` as ``tokenizer``.
        w_bit: weight bit-width, forwarded to the scale/clip search.
        q_config: quantization config, forwarded to the scale/clip search.
        n_samples: number of calibration samples requested.
        seqlen: calibration block size, passed as ``block_size``.
        auto_scale: if true, run ``auto_scale_block`` and ``apply_scale``.
        mse_range: if true, run ``auto_clip_block`` and ``apply_clip``.
        calib_data: calibration dataset name, passed as ``data``.

    Returns:
        dict with keys ``"scale"`` and ``"clip"``. Each is a list of results
        whose names are prefixed with the block's global op name (via
        ``append_str_prefix``), so they can be replayed with ``apply_awq``.

    Note:
        ``apply_scale``/``apply_clip`` are called on each block in place.
        Presumably they modify the model's weights; confirm in auto_scale /
        auto_clip.
    """
    from ..utils.calib_data import get_calib_dataset
    from ..utils.module import append_str_prefix, get_op_name

    layers = get_blocks(model)

    samples = get_calib_dataset(
        data=calib_data, tokenizer=enc, n_samples=n_samples, block_size=seqlen)
    samples = torch.cat(samples, dim=0)

    # Only the embeddings and the first block need to be on the GPU to
    # capture the first block's input.
    layers[0] = layers[0].cuda()
    move_embed(model, "cuda")

    inps, layer_kwargs = _capture_layer0_inputs(model, layers, samples)
    del samples

    layers[0] = layers[0].cpu()
    move_embed(model, "cpu")

    gc.collect()
    torch.cuda.empty_cache()

    awq_results = {
        "scale": [],
        "clip": [],
    }

    # solve layer by layer
    for layer in tqdm.tqdm(layers, desc="Running AWQ..."):
        # nn.Module.cuda() moves in place and returns the same object, so
        # `layer` still refers to the entry stored in `layers`.
        layer = layer.cuda()
        named_linears = get_named_linears(layer)

        # Record the input of every linear layer while computing this block's
        # output, which is fed to the next block.
        input_feat = defaultdict(list)
        handles = [
            linear.register_forward_hook(
                functools.partial(_cache_input_hook, name=name, feat_dict=input_feat))
            for name, linear in named_linears.items()
        ]
        try:
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            inps = layer(inps, **layer_kwargs)[0]
        finally:
            for h in handles:
                h.remove()
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        # Clear GPU memory
        torch.cuda.empty_cache()

        if auto_scale:  # if it applies, we should also modify the input_feat with scales
            scales_list = auto_scale_block(
                layer, layer_kwargs,
                w_bit=w_bit, q_config=q_config,
                input_feat=input_feat,
            )
            apply_scale(layer, scales_list, input_feat_dict=input_feat)
            # append prefix to make names global
            awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")

        # Clear GPU memory
        torch.cuda.empty_cache()

        if mse_range:
            clip_list = auto_clip_block(
                layer,
                w_bit=w_bit, q_config=q_config,
                input_feat=input_feat,
            )
            apply_clip(layer, clip_list)
            # append prefix to make names global
            awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")

        layer = layer.cpu()
        # Haotian: check activation replacement
        # NOTE(review): `inps` for the next block was computed above, before
        # scaling/clipping were applied to this block.
        del input_feat
        gc.collect()
        torch.cuda.empty_cache()

    return awq_results


def apply_awq(model, awq_results):
    """Replay previously searched AWQ results onto *model*.

    Args:
        model: the model to modify.
        awq_results: dict with ``"scale"`` and ``"clip"`` entries, as returned
            by ``run_awq``. Scales are applied first, then clipping.
    """
    for result_key, apply_fn in (("scale", apply_scale), ("clip", apply_clip)):
        apply_fn(model, awq_results[result_key])