import gc
from logging import getLogger
from typing import Tuple

import torch
from accelerate.utils import find_tied_parameters
from safetensors.torch import save_file as safe_save
from tqdm import tqdm

from ..nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear
from ..nn_modules.qlinear.qlinear_marlin import _get_perms, unpack_qzeros
from ..quantization import CHECKPOINT_FORMAT, QUANT_METHOD, BaseQuantizeConfig
from .accelerate_utils import load_checkpoint_in_model
from .import_utils import MARLIN_AVAILABLE, MARLIN_EXCEPTION
from .modeling_utils import recurse_getattr, recurse_setattr


if MARLIN_AVAILABLE:
    import autogptq_marlin_cuda

logger = getLogger(__name__)


def prepare_model_for_marlin_load(
    model,
    quantize_config: BaseQuantizeConfig,
    quant_linear_class,
    torch_dtype,
    current_model_save_name,
    device_map,
):
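    """
    Prepare `model` to be loaded with the Marlin kernel.

    Three paths are possible:
      1. The checkpoint is already serialized in the Marlin format: swap in Marlin QuantLinear layers and load it directly.
      2. A repacked Marlin copy of this GPTQ checkpoint is cached locally: swap in Marlin QuantLinear layers and load the cached file.
      3. Otherwise: load the GPTQ checkpoint, repack its weights into the Marlin format, untie any tied parameters
         (safetensors cannot serialize them) and cache the converted state dict for future loads.

    Returns the model (with Marlin QuantLinear layers) and the path of the checkpoint to load weights from.
    """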
    # The model (e.g. model.safetensors) is already serialized in the Marlin format, load it directly.
    if quantize_config.checkpoint_format == CHECKPOINT_FORMAT.MARLIN:
        model_save_name = current_model_save_name
        logger.info(f"Loading a GPTQ model, detected Marlin serialized format at {model_save_name}.")
        model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=False)
    else:
        model_save_name, is_cached = quantize_config.get_cache_file_path(
            quant_method=QUANT_METHOD.GPTQ, checkpoint_format=CHECKPOINT_FORMAT.MARLIN
        )

        # If GPTQ model has Marlin version cached locally, load from the cached version (no repacking needed).
        if is_cached:
            logger.info(
                f"Loading a GPTQ model, detected a cached repacked weight for Marlin kernel at {model_save_name}."
            )
            model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=False)

        # Otherwise, convert the model to Marlin format first and cache locally.
        else:
            # Loading the GPTQ checkpoint to do the conversion.
            # TODO: Avoid loading the model with wrong QuantLinear, and directly use
            # Marlin ones. The repacking can be done directly on the safetensors, just
            # as for AWQ checkpoints.
            load_checkpoint_in_model(
                model,
                dtype=torch_dtype,  # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
                checkpoint=current_model_save_name,
                device_map=device_map,
                offload_state_dict=True,
                offload_buffers=True,
            )
            # Convert model to marlin, repacking weights into Marlin format.
            model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=True)

            # Safetensors is unable to save tied weights, so we untie them here. Reference: https://github.com/huggingface/safetensors/issues/202
            tied_params = find_tied_parameters(model)

            for weight_group in tied_params:
                for param_name in weight_group:
                    if isinstance(recurse_getattr(model, param_name), torch.nn.Parameter):
                        recurse_setattr(
                            model,
                            param_name,
                            torch.nn.Parameter(recurse_getattr(model, param_name).clone()),
                        )
                    else:
                        recurse_setattr(
                            model,
                            param_name,
                            recurse_getattr(model, param_name).clone(),
                        )

            # Cache the converted model.
            safe_save(model.state_dict(), model_save_name)

    return model, model_save_name


# Validate marlin support
def _validate_marlin_device_support() -> bool:
    """
    Validates whether the current CUDA device is compatible with Marlin.
    ref: https://github.com/IST-DASLab/marlin?tab=readme-ov-file#requirements

    Returns:
        bool: indicates whether the CUDA device is compatible with Marlin
    """
    return torch.cuda.get_device_capability()[0] >= 8
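# For reference, torch.cuda.get_device_capability() returns e.g. (7, 5) for a T4 (fails the check above)
# versus (8, 0) for an A100, (8, 6) for an RTX 3090, (8, 9) for an RTX 4090 or (9, 0) for an H100
# (all pass, since only the major version is compared).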


# Adapted from https://github.com/rib-2/marlin/tree/conversion
def _validate_marlin_compatibility(cfg: BaseQuantizeConfig):
    if not MARLIN_AVAILABLE:
        return f"AutoGPTQ is not compiled with the Marlin kernel, with the following error: {MARLIN_EXCEPTION}"
    if cfg.bits != 4:
        return f"The quantized model uses a bitwidth different from 4 (found {cfg.bits})"
    if cfg.group_size != 128 and cfg.group_size != -1:
        return f"The quantized model uses a group size that is not 128 or -1 (found {cfg.group_size})"
    if not cfg.sym:
        return "The quantized model uses asymmetric quantization"
    if cfg.desc_act:
        return "The quantized model uses act-order (also called desc-act) scheme"
    if cfg.quant_method == QUANT_METHOD.AWQ:
        return "The awq_gemm checkpoint format is currently not compatible with Marlin"
    return None
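# A minimal sketch of how the two validators above are typically combined before enabling Marlin
# (hypothetical caller; only the None-vs-error-string contract comes from the functions themselves):
#
#     if not _validate_marlin_device_support():
#         raise ValueError("Marlin kernel requires a CUDA device with compute capability >= 8.0.")
#     unsupported_reason = _validate_marlin_compatibility(quantize_config)
#     if unsupported_reason is not None:
#         raise ValueError(f"Cannot use the Marlin kernel: {unsupported_reason}")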


@torch.no_grad()
def convert_to_marlin(model, model_quantlinear, quantization_config: BaseQuantizeConfig, repack: bool, strict: bool = False):
    """
    Converts GPTQ-packed weights to the Marlin format. This assumes that the model already meets Marlin kernel constraints.

    Arguments:
        repack (`bool`):
            Whether to repack the qweights from `model` into the Marlin's QuantLinear layers.
    """
    if repack:
        message = "Repacking weights to be compatible with Marlin kernel..."
    else:
        # TODO: load Marlin QuantLinear directly.
        message = "Overriding QuantLinear layers to use Marlin's QuantLinear..."

    for name, module in tqdm(model.named_modules(), desc=message, total=len(list(model.named_modules()))):
        if not isinstance(module, model_quantlinear):
            continue

        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1 :]

        # We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when
        # loading weights from checkpoints holding zero bias.
        with torch.device("meta"):
            new_module = MarlinQuantLinear(
                bits=4,
                group_size=module.group_size,
                infeatures=module.infeatures,
                outfeatures=module.outfeatures,
                bias=module.bias is not None,
                trainable=False,
            )

        # workspace is never in the state_dict, thus we need to allocate it manually.
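        # (The size `outfeatures // 128 * 16` follows the upstream Marlin reference, which expects a
        # zero-initialized int buffer of that size as global scratch space for the kernel.)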
        new_module.workspace = torch.zeros(module.outfeatures // 128 * 16, dtype=torch.int, device=module.device)

        # Repack the GPTQ qweight into the Marlin layout and permute the scales to match.
        if repack:
            marlin_repacked_weight = autogptq_marlin_cuda.gptq_repack(module.qweight)

            if strict:
                dequantized_qzeros = unpack_qzeros(module.qzeros)

                if not torch.all(dequantized_qzeros == 8):
                    raise ValueError(
                        "Marlin kernel is compatible only with checkpoints using symmetric quantization. "
                        f"Found non-symmetric quantization for the weight {name}."
                    )

            _, _scale_perm, _scale_perm_single = _get_perms()

            s = module.scales.data.clone()
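            # Permute the scales into the layout expected by the Marlin kernel: with grouped
            # quantization (group_size != infeatures) the per-group scale rows are permuted with
            # `_scale_perm`; otherwise the single per-output-channel scale row uses `_scale_perm_single`.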
            if module.group_size != module.infeatures:
                s = s.reshape((1, -1))
                s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
            else:
                s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
            s = s.reshape((-1, module.outfeatures)).contiguous()

            new_module.B = marlin_repacked_weight
            new_module.s = s
            new_module.bias = module.bias

            new_module = new_module.to(module.device)

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free cuda memory.
        del module
        if repack:
            del marlin_repacked_weight
        gc.collect()

    # Set quantization config to be Marlin.
    quantization_config.checkpoint_format = CHECKPOINT_FORMAT.MARLIN

    return model
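

# A minimal end-to-end sketch of how these helpers are reached (assumption: AutoGPTQ's public
# `AutoGPTQForCausalLM.from_quantized` entry point with its `use_marlin` flag; the model id below
# is hypothetical):
#
#     from auto_gptq import AutoGPTQForCausalLM
#
#     model = AutoGPTQForCausalLM.from_quantized(
#         "some-org/some-4bit-gptq-model",  # hypothetical GPTQ checkpoint (4-bit, group_size 128, sym, no desc_act)
#         device="cuda:0",
#         use_marlin=True,  # routes loading through prepare_model_for_marlin_load / convert_to_marlin
#     )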