base_loader.py 22.4 KB
Newer Older
yuguo960516's avatar
yuguo960516 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import copy
import logging
import os

import omegaconf
import oneflow as flow
from termcolor import colored

import libai.utils.distributed as dist
from libai.config import LazyCall
from libai.models.build import build_model

logger = logging.getLogger(__name__)


WEIGHTS_NAME_PT = "pytorch_model.bin"
CONFIG_NAME = "config.json"


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    """load state dict into model

    Args:
        model_to_load (nn.Module): Model to be loaded.
        state_dict (OrderedDict): State dict of pretrained model.
        start_prefix (str): Start prefix.

    Returns:
        list: error message about loading.
    """
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    def load(module, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model_to_load, prefix=start_prefix)

    return error_msgs


class ModelLoader(object):
    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        """Class used to load the [`transformers`](https://huggingface.co/models) pretrained model
        or `OneFlow` pretrained model.

        Subclasses must set ``self.base_model_prefix_2`` (the parameter-name
        prefix used by the LiBai model) before the loading methods are called.

        Args:
            model (libai.models): Model to be loaded in Libai.
            libai_cfg (dict): The config of model in LiBai, you can import it from
                `libai.config.configs.common.models`.
            pretrained_model_path (str): The directory path of pretrained model,
                which contains model weights file and config file.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether to return a dictionary containing missing keys, unexpected keys
                and error messages.
        """
        self.model = model
        self.libai_cfg = libai_cfg
        self.pretrained_model_path = pretrained_model_path
        self.kwargs = kwargs
        # Pop so the flag is not forwarded as an unknown kwarg by subclasses.
        self.output_loading_info = kwargs.pop("output_loading_info", False)

    def _state_dict_to_global(self, flow_state_dict=None, mode="libai"):
        """Tensor in OneFlow state dict to global according to model's sbp and placement.

        Args:
            flow_state_dict (OrderedDict): State dict of OneFlow's pretrained model.
                In ``"pytorch"`` mode only the main process holds the real dict;
                other ranks may pass ``None`` and get key metadata via broadcast.
            mode (str): ``"libai"`` when the state dict is already available on
                every rank, ``"pytorch"`` when it only exists on rank 0.

        Returns:
            OrderedDict: the state dict with tensors converted to global tensors
            (the input dict is also updated in place).
        """
        assert mode in ["libai", "pytorch"], f"not support for mode {mode}"
        if mode == "libai" or dist.is_main_process():
            prefix = self.base_model_prefix_2

            # Checkpoint
            has_prefix_module = any(
                s.startswith(self.base_model_prefix_2) for s in flow_state_dict.keys()
            )
            # Module
            expects_prefix_module = any(
                s.startswith(prefix) for s in self.model.state_dict().keys()
            )

            start_prefix = "" if has_prefix_module else prefix + "."
            loaded_keys = [start_prefix + key for key in flow_state_dict.keys()]
        else:
            # Non-main ranks in "pytorch" mode: placeholders, filled below.
            prefix, has_prefix_module, expects_prefix_module, loaded_keys = [None] * 4
            flow_state_dict = collections.OrderedDict()

        # Make the key bookkeeping identical on every rank.
        prefix = dist.broadcast_py_object(prefix, src=0)
        has_prefix_module = dist.broadcast_py_object(has_prefix_module, src=0)
        expects_prefix_module = dist.broadcast_py_object(expects_prefix_module, src=0)
        loaded_keys = dist.broadcast_py_object(loaded_keys, src=0)

        # to global
        for key, value in self.model.state_dict().items():
            if not expects_prefix_module:
                key = prefix + "." + key
            if key in loaded_keys:
                if not has_prefix_module:
                    key = ".".join(key.split(".")[1:])

                if mode == "pytorch":
                    # First lift the rank-0 local tensor to a global tensor placed
                    # on rank 0; non-main ranks contribute an empty tensor.
                    flow_state_dict[key] = flow.to_global(
                        flow_state_dict[key] if dist.is_main_process() else flow.Tensor(None),
                        sbp=flow.sbp.broadcast,
                        placement=flow.placement("cpu", ranks=[0]),
                    )

                # Re-distribute with the sbp/placement the model parameter expects.
                flow_state_dict[key] = flow.to_global(
                    flow_state_dict[key],
                    sbp=value.sbp,
                    placement=flow.placement("cpu", ranks=list(value.placement.ranks)),
                )
        return flow_state_dict

    def _load_pretrained_model(
        self,
        model,
        state_dict,
        pretrained_model_path,
        ignore_mismatched_sizes=False,
    ):
        """Load pretrained model.

        Args:
            model (libai.models): The model to be loaded.
            state_dict (OrderedDict): state dict.
            pretrained_model_path (str): pretrained model path.
            ignore_mismatched_sizes (bool):
                Whether or not to raise an error if some of the weights
                from the checkpoint do not have the same size as the
                weights of the model, defaults to `False`.

        Returns:
            tuple: ``(model, missing_keys, unexpected_keys, mismatched_keys, error_msgs)``
        """
        model_state_dict = model.state_dict()
        expected_keys = list(model_state_dict.keys())
        prefix = self.base_model_prefix_2

        loaded_keys = state_dict.keys()
        if len(prefix) > 0:
            has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
            expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
        else:
            has_prefix_module = False
            expects_prefix_module = False

        # Reconcile the prefix mismatch between checkpoint keys and model keys.
        remove_prefix_from_model = not has_prefix_module and expects_prefix_module
        add_prefix_to_model = has_prefix_module and not expects_prefix_module

        if remove_prefix_from_model:
            expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
            expected_keys = [
                ".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys
            ]
        elif add_prefix_to_model:
            expected_keys = [".".join([prefix, s]) for s in expected_keys]

        missing_keys = list(set(expected_keys) - set(loaded_keys))
        unexpected_keys = list(set(loaded_keys) - set(expected_keys))

        start_prefix = ""
        model_to_load = model
        if (
            len(self.base_model_prefix_2) > 0
            and not hasattr(model, self.base_model_prefix_2)
            and has_prefix_module
        ):
            start_prefix = self.base_model_prefix_2 + "."
        if (
            len(self.base_model_prefix_2) > 0
            and hasattr(model, self.base_model_prefix_2)
            and not has_prefix_module
        ):
            # Checkpoint holds only the base model; load into the submodule.
            model_to_load = getattr(model, self.base_model_prefix_2)
            if any(key in expected_keys_not_prefixed for key in loaded_keys):
                raise ValueError("The state dict of the model you are loading is corrupted.")

        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        ):
            # Collect (and drop from `state_dict`) checkpoint entries whose
            # shape disagrees with the model parameter of the same name.
            mismatched_keys = []
            if ignore_mismatched_sizes:
                for checkpoint_key in loaded_keys:
                    model_key = checkpoint_key
                    if remove_prefix_from_model:
                        model_key = f"{prefix}.{checkpoint_key}"
                    elif add_prefix_to_model:
                        model_key = ".".join(checkpoint_key.split(".")[1:])

                    if (
                        model_key in model_state_dict
                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                    ):
                        mismatched_keys.append(
                            (
                                checkpoint_key,
                                state_dict[checkpoint_key].shape,
                                model_state_dict[model_key].shape,
                            )
                        )
                        del state_dict[checkpoint_key]
            return mismatched_keys

        # Initialize here so the reporting below and the return statement are
        # well-defined even when `state_dict` is None (previously a NameError).
        mismatched_keys = []
        error_msgs = []
        if state_dict is not None:
            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                loaded_keys,
                add_prefix_to_model,
                remove_prefix_from_model,
                ignore_mismatched_sizes,
            )
            error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

        if dist.get_local_rank() == 0:
            if len(error_msgs) > 0:
                error_msg = "\n\t".join(error_msgs)
                raise RuntimeError(
                    f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}"
                )
            if len(unexpected_keys) > 0:
                logger.warning(
                    f"Some weights of the model checkpoint at {pretrained_model_path} "
                    "were not used when "
                    f"initializing {model.__class__.__name__}:\n {unexpected_keys}\n"
                )
            else:
                logger.info(
                    f"All model checkpoint weights were used when initializing "
                    f"{model.__class__.__name__}.\n"
                )
            if len(missing_keys) > 0:
                logger.warning(
                    f"Some weights of {model.__class__.__name__} were not initialized "
                    f"from the model checkpoint at {pretrained_model_path}:\n "
                    f"{missing_keys} \n"
                )
            elif len(mismatched_keys) == 0:
                logger.info(
                    f"All the weights of {model.__class__.__name__} were initialized "
                    f"from the model checkpoint at {pretrained_model_path}.\n"
                )
            if len(mismatched_keys) > 0:
                mismatched_warning = "\n".join(
                    [
                        f"- {key}: found shape {shape1} in the checkpoint and {shape2} "
                        "in the model instantiated"
                        for key, shape1, shape2 in mismatched_keys
                    ]
                )
                logger.warning(
                    f"Some weights of {model.__class__.__name__} were not initialized "
                    f"from the model checkpoint at {pretrained_model_path} "
                    f"and are newly initialized because the shapes did not "
                    f"match:\n{mismatched_warning}\n"
                )

        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs


class ModelLoaderLiBai(ModelLoader):
    """Class used to load `OneFlow` pretrained model.

    Args:
        model (libai.models): Model to be loaded in Libai.
        libai_cfg (dict): The config of model in LiBai, you can import it from
            `libai.config.configs.common.models`.
        pretrained_model_path (str): The directory path of pretrained model,
            which contains model weights file and config file.
        output_loading_info (`bool`, *optional*, defaults to `False`):
            Whether to return a dictionary containing missing keys, unexpected keys
            and error messages.
    """

    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
        self.base_model_prefix_2 = None  # prefix in LiBai; subclasses must set this

    def _load_flow_state_dict(self, state_dict_file):
        # load oneflow_model: global_src_rank=0 reads the checkpoint on rank 0
        # and makes the result available to every rank.
        state_dict = flow.load(state_dict_file, global_src_rank=0)
        return state_dict

    def load(self):
        """Load model.

        # For example:

        # .. code-block:: python

            >>> import libai
            >>> from libai.config.configs.common.models.bert import cfg
            >>> from model_utils import BertLoaderLiBai

            >>> loader = BertLoaderLiBai(
                    libai.models.BertModel,
                    cfg,
                    'path/bert-base-chinese'
                )
            >>> bert = loader.load()

        """

        if dist.is_main_process():
            assert os.path.isdir(
                self.pretrained_model_path
            ), f"{self.pretrained_model_path} must be a directory"

        flow_state_dict = self._load_flow_state_dict(self.pretrained_model_path)

        # Instance model
        if isinstance(self.model, omegaconf.dictconfig.DictConfig):
            self.model.cfg = self.libai_cfg
            self.model = build_model(self.model)
        else:
            self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg))

        # State_dict to global. Assign the return value for consistency with
        # ModelLoaderHuggerFace (the dict is also updated in place).
        flow_state_dict = self._state_dict_to_global(flow_state_dict, mode="libai")

        # Load
        (
            model,
            missing_keys,
            unexpected_keys,
            mismatched_keys,
            error_msgs,
        ) = self._load_pretrained_model(self.model, flow_state_dict, self.pretrained_model_path)

        if self.output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "mismatched_keys": mismatched_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info
        return model


class ModelLoaderHuggerFace(ModelLoader):
    """Class used to load the [`transformers`](https://huggingface.co/models)
    pretrained model.
    """

    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
        self.base_model_prefix_1 = None  # prefix in Transformers
        self.base_model_prefix_2 = None  # prefix in LiBai
        # Keep an untouched copy so _update_cfg_log can diff against it.
        self.origin_libai_cfg = copy.deepcopy(self.libai_cfg)
        self.changed_keys = set()  # Store the changed configuration

    def _convert_tensor(self, tensor):
        """Convert PyTorch tensor to OneFlow tensor.

        Args:
            tensor (torch.Tensor): The source tensor.

        Returns:
            flow.Tensor: The target tensor.
        """
        # Cast to float32 first: flow.Tensor is float32 and e.g. fp16/bf16
        # checkpoints cannot round-trip through numpy unchanged otherwise.
        tensor = tensor.float()
        return flow.Tensor(tensor.detach().cpu().numpy())

    def _convert_tensors(self, torch_state_dict):
        """Convert every tensor in the state dict to OneFlow, in place."""
        for k, v in torch_state_dict.items():
            torch_state_dict[k] = self._convert_tensor(v)

        return torch_state_dict

    def _fix_key(self, state_dict):
        """Fix the key in state dict: Convert "gamma" to "weight" and "beta" to "bias".

        Args:
            state_dict (OrderedDict): state dict of pretrained model.

        Returns:
            OrderedDict: State dict after fix key.
        """
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)
        return state_dict

    def _fix_qkv_ordering(
        self, qkv, head_size, num_heads, hidden_size=None, checkpoint_version=0.0
    ):
        """Reorder a fused query/key/value weight or bias from Huggingface's
        [q, k, v] layout to LiBai's per-head-interleaved layout.

        Whether ``qkv`` is a weight or a bias is inferred from its rank.
        """
        # TODO(xzp): Different versions checkpoint

        hidden_size = (head_size * num_heads) if hidden_size is None else hidden_size
        num_of_qkv = qkv.shape[0] // (head_size * num_heads)
        mode = "weight" if qkv.ndim > 1 else "bias"
        if mode == "weight":
            qkv = qkv.view([num_of_qkv, num_heads, head_size, hidden_size])
            qkv = (
                qkv.permute(1, 0, 2, 3)
                .contiguous()
                .view(num_of_qkv * head_size * num_heads, hidden_size)
            )
        elif mode == "bias":
            qkv = qkv.view(num_of_qkv, num_heads, head_size)
            qkv = qkv.permute(1, 0, 2).contiguous().view(-1)
        return qkv

    def _convert_state_dict(self, flow_state_dict, cfg):
        """A function used to convert the checkpoint file of Huggingface to LiBai.

        Args:
            flow_state_dict (OrderedDict): state dict whose tensors have already
                been converted to OneFlow (see ``_convert_tensors``).
            cfg (dict): model's default config dict in LiBai.

        Returns:
            OrderedDict: flow state dict.
        """
        raise NotImplementedError("_convert_state_dict not implemented")

    def _load_config_from_json(self, config_file):
        """load config from `config.json`, and update default config.

        Args:
            config_file (str): Path of config file.
        """

        raise NotImplementedError("_load_config_from_json not implemented")

    def _load_torch_state_dict(self, state_dict_file):
        """Load a PyTorch checkpoint file onto CPU.

        Raises:
            ImportError: If torch is not installed.
        """
        try:
            import torch
        except ImportError:
            raise ImportError("Load torch state dict need torch.")

        # load pytorch_model.bin
        state_dict = torch.load(state_dict_file, map_location="cpu")
        return state_dict

    def _update_cfg(self, keys_libai, value_target):
        """Update the libai_cfg according to target_cfg.

        Args:
            keys_libai (str): The key of libai_cfg.
            value_target (int | float): The value of target_cfg.
        """
        if keys_libai not in self.libai_cfg.keys():
            return
        if self.libai_cfg[keys_libai] != value_target:
            self.libai_cfg[keys_libai] = value_target

    def _update_cfg_log(self):
        """Log every config key that differs from the original LiBai config."""
        if dist.get_local_rank() == 0:
            for key in sorted(self.libai_cfg):
                if self.origin_libai_cfg[key] == self.libai_cfg[key]:
                    continue
                self.changed_keys.add(key)
                temp_key = colored(key, "yellow")
                logger.info(
                    f"changed libai model cfg {temp_key} : "
                    f"{self.origin_libai_cfg[key]} -> {self.libai_cfg[key]} "
                )
            logger.warning(
                "The following model configurations have been modified according "
                "to `config.json` or kwargs: \n"
                f"{self.changed_keys} \n"
            )

            if dist.get_pipeline_parallel_size() > 1:
                logger.warning(
                    colored(
                        "If you use pipeline parallel, please "
                        "confirm the setting of `train.dist.pipeline_num_layers` \n",
                        "red",
                    )
                )

    def load(self):
        """Load model.

        # For example:

        # .. code-block:: python

            >>> import libai
            >>> from configs.common.models.bert import cfg
            >>> from libai.models.utils import BertLoaderHugger

            >>> loader = BertLoaderHugger(
                    libai.models.BertModel,
                    cfg,
                    'path/bert-base-chinese'
                )
            >>> bert = loader.load()

        """
        if dist.is_main_process():
            if os.path.isdir(self.pretrained_model_path):
                # state_dict file pytorch
                if os.path.isfile(os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)):
                    model_file = os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)
                else:
                    raise EnvironmentError(
                        f"Error no file named {WEIGHTS_NAME_PT} found "
                        f"in directory {self.pretrained_model_path}."
                    )

                # config file
                if os.path.isfile(os.path.join(self.pretrained_model_path, CONFIG_NAME)):
                    config_file = os.path.join(self.pretrained_model_path, CONFIG_NAME)

                    # Load config and update config.
                    self._load_config_from_json(config_file)
                else:
                    import warnings

                    warnings.warn(
                        f"Error no file named {CONFIG_NAME} found in directory "
                        f"{self.pretrained_model_path}",
                        RuntimeWarning,
                    )
            else:
                raise EnvironmentError(f"{self.pretrained_model_path} is not a directory.")

            logger.info("loading torch model...")
            torch_state_dict = self._load_torch_state_dict(model_file)
            torch_state_dict = self._fix_key(torch_state_dict)
            logger.info("transferring torch model into oneflow model...")
            flow_state_dict = self._convert_tensors(torch_state_dict)
            # Pass the converted dict (previously `torch_state_dict` was passed;
            # that only worked because _convert_tensors mutates in place).
            flow_state_dict = self._convert_state_dict(flow_state_dict, self.libai_cfg)
        else:
            flow_state_dict = None

        # Non-main ranks need the (possibly json-updated) config before building.
        self.libai_cfg = dist.broadcast_py_object(self.libai_cfg, src=0)

        # Instance model
        logger.info("building LiBai model...")
        if isinstance(self.model, omegaconf.dictconfig.DictConfig):
            self.model.cfg = self.libai_cfg
            self.model = build_model(self.model)
        else:
            self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg))

        # State_dict to global
        logger.info("transferring state_dict local to global...")
        flow_state_dict = self._state_dict_to_global(flow_state_dict, mode="pytorch")

        logger.info("loading model weights into LiBai...")
        # Load
        (
            model,
            missing_keys,
            unexpected_keys,
            mismatched_keys,
            error_msgs,
        ) = self._load_pretrained_model(self.model, flow_state_dict, self.pretrained_model_path)

        if self.output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "mismatched_keys": mismatched_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info
        return model