# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import os
from pathlib import Path
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn

# `Version` gives a robust comparison of the modelopt version string below
# (plain string comparison misorders e.g. "0.9.0" vs "0.20.0").
from packaging.version import Version

from megatron.core import dist_checkpointing
from megatron.training import get_args
from megatron.training.checkpointing import _load_base_checkpoint, load_checkpoint
from megatron.training.utils import print_rank_0, unwrap_model

try:
    import modelopt
    from modelopt.torch.opt.plugins import (
        get_sharded_modelopt_state,
        restore_modelopt_state_metadata,
    )
    from modelopt.torch.opt.plugins.mcore_dist_checkpointing import _get_gpt_sharded_modelopt_state
except ImportError as e:
    raise ImportError('Required "nvidia-modelopt[torch]" is not installed!') from e


# Map known NeMo weight subdirectory names to the sharded state_dict key prefix
# added by the corresponding lightning module wrapper.
NEMO_WEIGHT_DIR_NAMES = {
    "model_weights": "model.",
    "weights": "module.",
}


def get_sharded_load_dir(load_dir: str) -> Tuple[Path, str]:
    """Resolve the sharded weight directory and the sharded state_dict key prefix.

    MLM checkpoints are detected via the `latest_checkpointed_iteration.txt`
    tracker file; NeMo checkpoints are detected via a known weight subdirectory
    name (see NEMO_WEIGHT_DIR_NAMES).

    Returns:
        A (sharded_load_dir, sharded_prefix) tuple.
    """
    sharded_prefix = ""
    sharded_load_dir = None
    # Read the tracker file and set the iteration if this is a MLM sharded checkpoint.
    tracker_filename = os.path.join(load_dir, 'latest_checkpointed_iteration.txt')
    # If there is no tracker file, assume this is a NeMo sharded checkpoint.
    if os.path.isfile(tracker_filename):
        with open(tracker_filename, 'r') as f:
            metastring = f.read().strip()
            try:
                iteration = int(metastring)
                sharded_load_dir = Path(load_dir) / 'iter_{:07d}'.format(iteration)
            except ValueError:
                sharded_load_dir = Path(load_dir) / metastring
    else:
        for nemo_dir_name, prefix in NEMO_WEIGHT_DIR_NAMES.items():
            nemo_weight_dir = Path(load_dir) / nemo_dir_name
            if os.path.isdir(nemo_weight_dir):
                sharded_prefix = prefix
                sharded_load_dir = nemo_weight_dir
                break

    if sharded_load_dir is None:
        raise ValueError("{} is not an MLM or NeMo sharded checkpoint!".format(load_dir))

    return sharded_load_dir, sharded_prefix
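
# A minimal usage sketch for get_sharded_load_dir (the paths below are
# hypothetical, for illustration only):
#
#   # MLM: /ckpt/latest_checkpointed_iteration.txt contains "1000"
#   sharded_dir, prefix = get_sharded_load_dir("/ckpt")
#   # -> (Path("/ckpt/iter_0001000"), "")
#
#   # untarred .nemo: /nemo_ckpt/model_weights/ exists
#   sharded_dir, prefix = get_sharded_load_dir("/nemo_ckpt")
#   # -> (Path("/nemo_ckpt/model_weights"), "model.")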


def load_modelopt_state(load_dir: Optional[str] = None, model: Optional[nn.Module] = None) -> Dict:
    """Loading modelopt_state without loading the model.

    If --use-dist-ckpt, we try to load from the sharded modelopt_state. This will not load the model
    state_dict. Otherwise, if the checkpoint is not sharded, we load the base checkpoint (that
    contains the model state as well) and extract the modelopt_state.

    Args:
        load_dir: optionally provide a different loading path
        model: required when loading a sharded checkpoint
    """
    args = get_args()

    if load_dir is None:
        load_dir = args.load

    if args.use_dist_ckpt:
        assert model is not None, "the `model` argument is required when `args.use_dist_ckpt` is True"

        sharded_load_dir, _ = get_sharded_load_dir(load_dir)
        modelopt_state_dir = sharded_load_dir / "modelopt_state"
        if modelopt_state_dir.exists():
            # The sharded modelopt_state layout depends on the speculative-decoding
            # modes (medusa/eagle), so read the head/layer counts from the
            # unsharded common part of the modelopt_state first.
            common_modelopt_state = torch.load(modelopt_state_dir / "common.pt")
            extra_kwargs = {}
            for mode, mode_cfg in common_modelopt_state["modelopt_state_dict"]:
                if mode == "medusa":
                    extra_kwargs["num_medusa_heads"] = mode_cfg["config"]["medusa_num_heads"]
                if mode == "eagle" and Version(modelopt.__version__) >= Version("0.20.0"):
                    print_rank_0("eagle mode config: {}".format(mode_cfg["config"]))
                    extra_kwargs["num_eagle_layers"] = mode_cfg["config"]["eagle_num_layers"]
            print_rank_0("Loading sharded modelopt_state ({})".format(modelopt_state_dir))
            modelopt_state = restore_modelopt_state_metadata(
                dist_checkpointing.load(
                    _get_gpt_sharded_modelopt_state(
                        num_layers=args.num_layers, **extra_kwargs
                    ),
                    modelopt_state_dir,
                )
            )
            return modelopt_state
        else:
            print_rank_0(
                "sharded modelopt_state ({}) does not exist!".format(modelopt_state_dir)
            )
            return {}
    else:
        print_rank_0("Loading modelopt_state from base checkpoint ({})".format(load_dir))
        try:
            state_dict, _, _, _ = _load_base_checkpoint(load_dir, args, rank0=False)
        except Exception:
            print_rank_0("Failed to load base checkpoint via megatron _load_base_checkpoint!")
            return {}
        if state_dict is None:
            return {}
        return state_dict.get("modelopt_state", {})
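
# A minimal usage sketch (illustrative; assumes megatron args are initialized
# and `model` is passed when the checkpoint is sharded). Re-applying the modes
# via modelopt.torch.opt.restore_from_modelopt_state is one assumed follow-up:
#
#   modelopt_state = load_modelopt_state(model=model)
#   if modelopt_state:
#       from modelopt.torch.opt import restore_from_modelopt_state
#       restore_from_modelopt_state(model, modelopt_state)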


def load_modelopt_checkpoint(
    model,
    optimizer=None,
    opt_param_scheduler=None,
    strict: bool = True,
    additional_sharded_prefix: str = "",
    load_arg: str = "load",
) -> None:
    """Load a sharded (untar .nemo or megatron --use-dist-ckpt) or unsharded checkpoint.

    Essentially, the function is detecting whether the checkpoint is a .nemo sharded checkpoint.
    If so, we load the sharded state_dict with additional_sharded_prefix `model.`.
    This additional prefix is tha artifact of the lightning module wrapper. Once the sharded
    state_dict is loaded, we use a state_dict pre_hook to pop this additional prefix (`model.`)
    from all state_dict keys.

    If this is not a .nemo sharded checkpoint, then this function will simply call
    load_checkpoint. See megatron.checkpointing.load_checkpoint for explanation.

    Args:
        additional_sharded_prefix: append additional prefix to align the sharded checkpoint keys.
            When loading an .nemo sharded checkpoint, this is usually `model.`. Otherwise, this is
            typically an empty string.
    """

    def _remove_prefix_state_dict_pre_hook(
        state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs,
    ):
        """Pytorch state_dict pre_hook to remove prefix of the state_dict keys."""
        if not additional_sharded_prefix:
            return
        key_rewrite_list = []
        for key, _ in state_dict.items():
            if key.startswith(additional_sharded_prefix):
                key_rewrite_list.append(key)
        for old_key in key_rewrite_list:
            new_key = old_key[len(additional_sharded_prefix) :]
            state_dict[new_key] = state_dict.pop(old_key)
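
    # Illustrative example: with additional_sharded_prefix == "model.", this hook
    # rewrites a loaded key "model.decoder.layers.0.mlp.weight" to
    # "decoder.layers.0.mlp.weight" before load_state_dict consumes it.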

    args = get_args()
    load_dir = getattr(args, load_arg)
    sharded_load_dir, additional_sharded_prefix = get_sharded_load_dir(load_dir)

    unwrapped_model = unwrap_model(model)

    if args.ckpt_format == "torch":
        state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint(
            load_dir, args, rank0=False,
        )
        model_state_dict = state_dict["model"]
        unwrapped_model[0].load_state_dict(model_state_dict, strict=False)
    elif sharded_load_dir.exists() and optimizer is None and opt_param_scheduler is None:
        sharded_state_dict = unwrapped_model[0].sharded_state_dict(prefix=additional_sharded_prefix)
        if additional_sharded_prefix:
            unwrapped_model[0]._register_load_state_dict_pre_hook(
                _remove_prefix_state_dict_pre_hook
            )
        model_state_dict = dist_checkpointing.load(sharded_state_dict, sharded_load_dir)
        unwrapped_model[0].load_state_dict(model_state_dict, strict=False)
    else:
        _ = load_checkpoint(model, optimizer, opt_param_scheduler, strict=strict, load_arg=load_arg)
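

# A minimal end-to-end sketch (illustrative; the model construction is
# hypothetical, and megatron args must already be initialized with --load
# pointing at an untarred .nemo directory or an MLM sharded checkpoint):
#
#   model = [model_provider()]           # hypothetical single model chunk
#   load_modelopt_checkpoint(model)      # weights only (optimizer=None)
#   # or, to also restore optimizer state via megatron's load_checkpoint path:
#   load_modelopt_checkpoint(model, optimizer, opt_param_scheduler, strict=False)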