# extract_scales.py
import argparse
import glob
import json
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch
from safetensors.torch import safe_open

from vllm.model_executor.layers.quantization.schema import QuantParamSchema


# Adapted from vllm/model_executor/model_loader/weight_utils.py
# The main differences are that we add the NPZ format and simplify
# its functionality drastically for our purposes (e.g. we assume that
# the quantized model exists locally and there is no need to download it)
def _prepare_hf_weights(
    quantized_model_dir: str,
    load_format: str = "auto",
    fall_back_to_pt: bool = True,
) -> Tuple[List[str], bool]:
    """Find the tensor files of a locally stored quantized model.

    Filename patterns are tried in order and the search stops at the first
    pattern that matches at least one file in ``quantized_model_dir``.

    Args:
        quantized_model_dir: Directory containing the quantized model.
        load_format: "auto" (tries "*.safetensors", then "*.bin"),
            "safetensors", "pt" or "npz".
        fall_back_to_pt: When True, "*.pt" is tried last if the patterns
            selected by ``load_format`` matched nothing.

    Returns:
        A ``(hf_weights_files, use_safetensors)`` tuple: the matched file
        paths, and whether they were matched by the "*.safetensors" pattern.

    Raises:
        FileNotFoundError: ``quantized_model_dir`` is not a directory.
        ValueError: ``load_format`` is not one of the supported values.
        RuntimeError: No usable weight file was found.
    """
    if not os.path.isdir(quantized_model_dir):
        raise FileNotFoundError(
            f"The quantized model directory `{quantized_model_dir}` "
            "does not exist.")
    # Some quantized models use .pt files for storing the weights.
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npz":
        allow_patterns = ["*.npz"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")
    # NOTE(review): the fallback also applies to "npz"; whether callers can
    # read .pt files in that case is not decided here -- confirm.
    if fall_back_to_pt and "*.pt" not in allow_patterns:
        allow_patterns.append("*.pt")

    hf_weights_files: List[str] = []
    use_safetensors = False
    for pattern in allow_patterns:
        hf_weights_files = glob.glob(
            os.path.join(quantized_model_dir, pattern))
        if hf_weights_files:
            # Derive the flag from the pattern that actually matched, so a
            # "*.pt" fallback is never reported as safetensors.
            use_safetensors = pattern == "*.safetensors"
            break

    if not use_safetensors:
        # Exclude files that are not needed for inference.
        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
        blacklist = (
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        )
        hf_weights_files = [
            f for f in hf_weights_files if not f.endswith(blacklist)
        ]

    if not hf_weights_files:
        raise RuntimeError(
            f"Cannot find any model weights with `{quantized_model_dir}`")

    return hf_weights_files, use_safetensors


# Adapted from vllm/model_executor/model_loader/weight_utils.py
def _hf_tensorfile_iterator(filename: str, load_format: str,
                            use_safetensors: bool):
    """Yield ``(name, tensor)`` pairs stored in a single tensor file.

    The reader is picked as follows:
      * ``load_format == "npz"``: the file is opened with ``np.load`` and
        every array is converted with ``torch.from_numpy``
        (``use_safetensors`` must be False).
      * ``use_safetensors``: the file is opened with ``safe_open``.
      * otherwise: the file is read with ``torch.load`` onto the CPU and its
        items are yielded.
    """
    if load_format == "npz":
        assert not use_safetensors
        with np.load(filename) as archive:
            for key in archive.files:
                yield key, torch.from_numpy(archive[key])
        return

    if use_safetensors:
        with safe_open(filename, framework="pt") as handle:
            for key in handle.keys():  # NOQA: SIM118
                yield key, handle.get_tensor(key)
        return

    checkpoint = torch.load(filename, map_location="cpu")
    yield from checkpoint.items()
    # Drop the loaded checkpoint once every item has been handed out.
    del checkpoint
    torch.cuda.empty_cache()


def _kv_scales_extractor(
        hf_tensor_files: Iterable[str],
        use_safetensors: bool,
        rank_keyword: str = "rank",
        expected_tp_size: Optional[int] = None,
        load_format: Optional[str] = None,
) -> Optional[Dict[int, Dict[int, float]]]:
    """
    Given a list of files containing tensor data, attempt to extract KV cache
    scales from these files. Intended as a helper function taking in the output
    from _prepare_hf_weights.
    Args:
    hf_tensor_files     Paths of the tensor files, one file per TP rank.
    use_safetensors     Forwarded to the tensor file reader.
    rank_keyword        Matches the number immediately after this keyword in the
                        tensor filename to determine the TP rank corresponding
                        to said tensor file. Only the file's base name is
                        searched, never its parent directories.
    expected_tp_size    If specified, the TP size of the tensor files is checked
                        against this and an error is raised if they don't match.
    load_format         Format forwarded to the tensor file reader; "npz" also
                        switches the module delimiter in tensor names from "."
                        to ":". If omitted, the `load_format` attribute of a
                        module-level `args` object is used when one exists,
                        otherwise "auto".
    Returns a dictionary mapping TP ranks to their relevant KV cache scales.
    The per-rank scales are themselves represented as a dictionary of layer
    indices to the respective per-layer scale. Returns None (after printing a
    warning) if no KV cache scaling factor was found in any file.
    Raises RuntimeError if a TP rank cannot be determined, if two files share
    a rank, or if a scaling factor is not a scalar; raises AssertionError if
    the set of observed ranks is inconsistent.
    """
    for char in rank_keyword:
        assert not char.isdecimal(
        ), f"Rank keyword {rank_keyword} contains a numeric character!"

    if load_format is None:
        # Backward compatibility: when run as a script, the parsed CLI
        # arguments live in the module-level `args`. Looking it up this way
        # avoids a NameError when the function is used from an import.
        cli_args = globals().get("args")
        load_format = getattr(cli_args, "load_format", "auto")
    module_delimiter = ":" if load_format == "npz" else "."

    # Materialize so that any iterable is accepted (len() is needed below).
    tensor_files = list(hf_tensor_files)
    rank_scales_map: Dict[int, Dict[int, float]] = {}
    for tensor_file in tensor_files:
        # Parse the rank from the base name only, so directory names that
        # happen to contain the keyword cannot confuse the lookup.
        file_name = os.path.basename(tensor_file)
        try:
            rank_idx = file_name.find(rank_keyword)
            if rank_idx != -1:
                start_idx = rank_idx + len(rank_keyword)
                stop_idx = start_idx
                while (stop_idx < len(file_name)
                       and file_name[stop_idx].isdecimal()):
                    stop_idx += 1
                if stop_idx == start_idx:
                    raise RuntimeError("Did not find rank # in filename.")
                rank = int(file_name[start_idx:stop_idx])
            elif len(tensor_files) == 1:
                # Since there is only one tensor file, we can assume
                # that it's intended for TP rank 0
                rank = 0
            else:
                raise RuntimeError(
                    f"Filename does not contain '{rank_keyword}'.")
        except RuntimeError:
            print("Unable to determine TP rank "
                  f"corresponding to file '{tensor_file}'")
            raise

        if rank in rank_scales_map:
            raise RuntimeError(
                f"Tensor file '{tensor_file}' shares TP rank {rank} "
                "with another tensor file.")
        layer_scales_map: Dict[int, float] = {}
        rank_scales_map[rank] = layer_scales_map

        for name, param in _hf_tensorfile_iterator(tensor_file, load_format,
                                                   use_safetensors):
            if "kv_cache_scaling_factor" not in name:
                continue
            # The layer index is the only purely numeric component of the
            # delimiter-separated tensor name.
            nums = [
                int(s) for s in name.split(module_delimiter) if s.isdecimal()
            ]
            assert len(nums) == 1, f"Could not determine layer idx for {name}"
            layer_idx = nums[0]
            assert layer_idx not in layer_scales_map, f"Duplicate scaling"\
                f" factor corresponding to layer {layer_idx}"
            try:
                layer_scales_map[layer_idx] = param.item()
            except RuntimeError:
                print("This utility supports only per-tensor scalar scales "
                      f"for now. The tensor\n {name} = {param} \nis an "
                      "invalid scale factor.")
                raise

    if not any(rank_scales_map.values()):
        # Note: this is true even if the rank_scales_map is empty
        print("WARNING: No KV cache scale factors found. No output saved.")
        return None
    empirical_tp_world_size = max(rank_scales_map) + 1
    if expected_tp_size is not None:
        assert expected_tp_size == empirical_tp_world_size, \
            f"User expected TP world size = {expected_tp_size} " \
            "from model but tool is expecting TP world size = " \
            f"{empirical_tp_world_size} from model instead."
    for i in range(empirical_tp_world_size):
        assert i in rank_scales_map, "Expected TP world size = "\
            f"{empirical_tp_world_size} but did not find KV " \
            f"cache scaling factors for TP rank {i}"
    print(f"Found TP world size = {empirical_tp_world_size} "
          "when extracting KV cache scales!")
    return rank_scales_map


def _metadata_extractor(
    quantized_model_dir: str,
    metadata_extract_fns: Dict[str, Callable[[Dict[str, Any]], Any]],
) -> Dict[str, Any]:
    """
    Given a directory containing quantized model files, this function
    aims to extract metadata from the JSON files within this directory.
    Each JSON file is expected to represent a dictionary in JSON
    format (referred to as a "JSON-dictionary"). Metadata extraction is
    defined by a dictionary called metadata_extract_fns, where each
    metadata field name is mapped to an extraction function.

    These extraction functions are designed to take a JSON-dictionary
    as their only argument and return the corresponding metadata.
    While extraction functions are permitted to raise exceptions, they
    should only raise a KeyError or ValueError if the metadata field
    cannot be extracted from the current JSON-dictionary, yet there's
    a possibility of finding it in another JSON-dictionary.

    Files are read as UTF-8 and visited in sorted order. A file that
    cannot be decoded or parsed, or that does not hold a dictionary, is
    skipped with a printed notice.

    The function returns a dictionary that maps metadata fields to
    their extracted data. The keys of this dictionary correspond exactly
    to those in metadata_extract_fns. If any fields fail to be extracted,
    their corresponding values are set to None, and a warning is printed.

    Raises FileNotFoundError if quantized_model_dir is not a directory and
    RuntimeError if two files yield different values for the same field.
    """
    if not os.path.isdir(quantized_model_dir):
        raise FileNotFoundError(
            f"The quantized model directory `{quantized_model_dir}` "
            "does not exist.")
    # Sorted so that results and error messages do not depend on the
    # arbitrary order in which glob lists the directory.
    metadata_files = sorted(
        glob.glob(os.path.join(quantized_model_dir, "*.json")))

    result: Dict[str, Any] = {}
    for file in metadata_files:
        try:
            with open(file, encoding="utf-8") as f:
                metadata = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Undecodable bytes are treated like malformed JSON: the file
            # simply is not a metadata file.
            print(f"Could not parse `{file}` as a valid metadata file,"
                  " skipping it.")
            continue
        if not isinstance(metadata, dict):
            print(f"The file `{file}` does not correspond to a "
                  "JSON-serialized dictionary, skipping it.")
            continue
        for metadata_name, extract_fn in metadata_extract_fns.items():
            try:
                metadata_info = extract_fn(metadata)
            except (KeyError, ValueError):
                # It is possible that a given file does not contain some
                # of our selected metadata as it could be located in some
                # other metadata file.
                # 'EFINAE': extract_fn failure is not an error.
                continue
            if metadata_name not in result:
                result[metadata_name] = metadata_info
            elif metadata_info != result[metadata_name]:
                raise RuntimeError(
                    "Metadata mismatch! Originally found "
                    f"{metadata_name} = {result[metadata_name]} but "
                    f"now found {metadata_name} = {metadata_info} in "
                    f"`{file}`")

    # Warn if we cannot find any of the requested metadata
    for metadata_name in metadata_extract_fns:
        if metadata_name not in result:
            print("WARNING: Unable to find requested metadata field "
                  f"`{metadata_name}`, setting it to None.")
            result[metadata_name] = None

    return result


def main(args):
    """Extract KV cache scaling factors and save them as a JSON file.

    Reads the metadata (model type, TP size, dtype) from the JSON files in
    ``args.quantized_model``, locates the tensor files, extracts the per-rank
    KV cache scales and writes them, wrapped in a ``QuantParamSchema``, to
    ``args.output_name`` inside ``args.output_dir`` (or inside the model
    directory when no output directory is given).

    Args:
        args: Parsed command-line arguments providing ``quantized_model``,
            ``load_format``, ``output_dir``, ``output_name`` and ``tp_size``.

    If no KV cache scaling factor is found, nothing is written.
    """
    metadata_extract_fns = {
        "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
        "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
        "model_dtype": lambda json_dict: json_dict["dtype"]
    }
    recovered_metadata = _metadata_extractor(args.quantized_model,
                                             metadata_extract_fns)
    metadata_tp_size = recovered_metadata["tp_size"]
    if args.tp_size is not None and metadata_tp_size is not None:
        assert args.tp_size == metadata_tp_size, \
          f"User expected TP world size = {args.tp_size} " \
          f"but found TP world size = {metadata_tp_size} from metadata!"
    # The user-supplied value wins; the metadata value is the fallback.
    expected_tp_size = args.tp_size or metadata_tp_size
    rank_keyword = "rank"
    hf_tensor_files, use_safetensors = _prepare_hf_weights(
        args.quantized_model, args.load_format)
    rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
                                           rank_keyword, expected_tp_size)
    if rank_scales_map is None:
        # The extractor found no scaling factors and has already warned
        # that no output is saved; there is nothing to format or write.
        return
    # Postprocess: formatting to the current schema. Consider pulling it
    # out into a dedicated function should it ever become more complicated.
    rank_scales_map = {
        rank: {k: scale[k]
               for k in sorted(scale.keys())}
        for rank, scale in rank_scales_map.items()
    }
    # TODO: Expand this with activation and weights scaling factors when
    # they are used in the future
    schema = QuantParamSchema(
        model_type=recovered_metadata["model_type"],
        kv_cache={
            "dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else
                      recovered_metadata["model_dtype"]),
            "scaling_factor":
            rank_scales_map
        },
    )

    if args.output_dir is None:
        output_file = os.path.join(args.quantized_model, args.output_name)
    else:
        os.makedirs(args.output_dir, exist_ok=True)
        output_file = os.path.join(args.output_dir, args.output_name)

    with open(output_file, 'w', encoding="utf-8") as f:
        f.write(schema.model_dump_json(indent=4))
    print(f"Completed! KV cache scaling factors saved to {output_file}")


if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments), listed
    # in the order in which they are registered with the parser.
    cli_options = [
        ("--quantized_model", {
            "help":
            "Specify the directory containing a single quantized HF model. "
            "It is expected that the quantization format is FP8_E4M3, for use "
            "on ROCm (AMD GPU).",
            "required": True,
        }),
        ("--load_format", {
            "help":
            "Optionally specify the format of the model's tensor files "
            "containing the KV cache scaling factors.",
            "choices": ["auto", "safetensors", "npz", "pt"],
            "default": "auto",
        }),
        ("--output_dir", {
            "help":
            "Optionally specify the output directory. By default the "
            "KV cache scaling factors will be saved in the model directory, "
            "however you can override this behavior here.",
            "default": None,
        }),
        ("--output_name", {
            "help": "Optionally specify the output filename.",
            # TODO: Change this once additional scaling factors are enabled
            "default": "kv_cache_scales.json",
        }),
        ("--tp_size", {
            "help":
            "Optionally specify the tensor-parallel (TP) size that the "
            "quantized model should correspond to. If specified, during KV "
            "cache scaling factor extraction the observed TP size will be "
            "checked against this and an error will be raised if there is "
            "a mismatch. If not specified, the quantized model's expected "
            "TP size is instead inferred from the largest TP rank observed. "
            "The expected TP size is cross-checked against the TP ranks "
            "observed in the quantized model and an error is raised if any "
            "discrepancies are found.",
            "default": None,
            "type": int,
        }),
    ]

    parser = argparse.ArgumentParser(
        description="This simple utility extracts the "
        "KV cache scaling factors from a quantized HF model "
        "and saves them to a JSON file compatible with later "
        "use by vLLM (pass this file to the appropriate "
        "runtime typically using the argument "
        "--quantization-param-path <filename>). This is only used "
        "if the KV cache dtype is FP8 and on ROCm (AMD GPU).")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    # Keep `args` at module level: code elsewhere in this file looks up the
    # parsed arguments through this global name.
    args = parser.parse_args()

    main(args)