common.py 9.18 KB
Newer Older
xingjinliang's avatar
xingjinliang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import math

import torch

from megatron.core import parallel_state
from megatron.core.dist_checkpointing import load, load_plain_tensors, save
from megatron.core.dist_checkpointing.dict_utils import diff
from megatron.core.dist_checkpointing.serialization import (
    get_default_load_sharded_strategy,
    get_default_save_sharded_strategy,
)
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
    FullyParallelLoadStrategyWrapper,
    FullyParallelSaveStrategyWrapper,
)
from megatron.core.dist_checkpointing.validation import StrictHandling
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils


def common_test_simple_sharded_state_dict_save_load(
    initialize_model_fn, tmp_path_dist_ckpt, src_layer_spec_fn, dst_layer_spec_fn
):
    """Round-trip a sharded state dict through save/load as a smoke test.

    Under a fixed TP=2 / PP=4 layout, a model built from ``src_layer_spec_fn`` is saved, and a
    second model built from ``dst_layer_spec_fn`` loads that checkpoint. Tensor values are not
    compared. The test only requires that loading succeeds and that every missing or unexpected
    key is an ``_extra_state`` entry.

    Args:
        initialize_model_fn: model factory, called as
            ``initialize_model_fn(<1 or 2>, layer_spec_fn, tensor_model_parallel_size=...,
            pipeline_model_parallel_size=...)``. The leading integer distinguishes the saved
            model (1) from the loading model (2); presumably a seed, confirm against the factory.
        tmp_path_dist_ckpt: base directory (supports ``/``) under which the checkpoint is written.
        src_layer_spec_fn: layer spec for the model that is saved.
        dst_layer_spec_fn: layer spec for the model that loads the checkpoint.
    """
    tp_size, pp_size = 2, 4
    parallel_layout = {
        'tensor_model_parallel_size': tp_size,
        'pipeline_model_parallel_size': pp_size,
    }
    Utils.initialize_model_parallel(tp_size, pp_size)
    model = initialize_model_fn(1, src_layer_spec_fn, **parallel_layout)
    with TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model') as ckpt_dir:
        save(model.sharded_state_dict(), ckpt_dir)

        # Rebuild with the destination spec and load what was just saved
        model = initialize_model_fn(2, dst_layer_spec_fn, **parallel_layout)
        loaded_state, missing, unexpected = load(
            model.sharded_state_dict(), ckpt_dir, strict=StrictHandling.RETURN_ALL
        )
        # Extra states may legitimately differ between the two specs; nothing else may.
        for key in (*missing, *unexpected):
            assert '_extra_state' in key
        model.load_state_dict(loaded_state)
    Utils.destroy_model_parallel()


def common_test_parallel_reconfiguration_e2e(
    initialize_model_fn,
    tmp_path_dist_ckpt,
    src_tp_pp,
    dest_tp_pp,
    src_layer_spec_fn,
    dst_layer_spec_fn,
    use_fpsl,
    load_order="tp-dp-pp",
    store_order="tp-dp-pp",
    src_tp_pp_kwargs=None,
    dst_tp_pp_kwargs=None,
):
    """Test model saving and loading with different TP/PP.

    Model A (``src_layer_spec_fn`` under ``src_tp_pp``) is saved as checkpoint A. Model B
    (``dst_layer_spec_fn`` under ``dest_tp_pp``) loads checkpoint A and is saved as
    checkpoint B. Both checkpoints are then read as plain tensors under a TP=1 / PP=1 layout
    and must be identical.

    Args:
        initialize_model_fn: model factory, called as
            ``initialize_model_fn(<1 or 2>, layer_spec_fn, tensor_model_parallel_size=...,
            pipeline_model_parallel_size=...)``.
        tmp_path_dist_ckpt: base directory (supports ``/``) for the temporary checkpoints.
        src_tp_pp: positional args for ``Utils.initialize_model_parallel`` when building
            model A; ``[0]`` is the TP size and ``[1]`` the PP size.
        dest_tp_pp: same as ``src_tp_pp`` but for model B.
        src_layer_spec_fn: layer spec for model A.
        dst_layer_spec_fn: layer spec for model B.
        use_fpsl: if True, save A with ``FullyParallelSaveStrategyWrapper`` and load it into B
            with ``FullyParallelLoadStrategyWrapper``; otherwise use the default strategies.
        load_order: rank ``order`` for the *source* layout (used when saving A).
            NOTE(review): the name looks swapped relative to its use; kept as-is for
            backward compatibility with existing callers.
        store_order: rank ``order`` for the *destination* layout (used when loading into B).
        src_tp_pp_kwargs: extra kwargs for ``Utils.initialize_model_parallel`` for model A.
        dst_tp_pp_kwargs: extra kwargs for ``Utils.initialize_model_parallel`` for model B.
    """
    Utils.initialize_model_parallel(*src_tp_pp, **(src_tp_pp_kwargs or {}), order=load_order)
    with TempNamedDir(
        tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_A'
    ) as ckpt_dir_A, TempNamedDir(
        tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_B'
    ) as ckpt_dir_B:
        # Save checkpoint A
        gpt_model_A = initialize_model_fn(
            1,
            src_layer_spec_fn,
            tensor_model_parallel_size=src_tp_pp[0],
            pipeline_model_parallel_size=src_tp_pp[1],
        )
        save_strategy = get_default_save_sharded_strategy()
        if use_fpsl:
            save_strategy = FullyParallelSaveStrategyWrapper(
                save_strategy,
                parallel_state.get_data_parallel_group(with_context_parallel=True),
                True,
            )
        save(gpt_model_A.sharded_state_dict(), ckpt_dir_A, save_strategy)
        regular_state_dict_A = gpt_model_A.state_dict()
        Utils.destroy_model_parallel()

        # Load checkpoint A with different TP/PP and save as checkpoint B
        # No FPS this time, only FPL
        Utils.initialize_model_parallel(*dest_tp_pp, **(dst_tp_pp_kwargs or {}), order=store_order)
        gpt_model_B = initialize_model_fn(
            2,
            dst_layer_spec_fn,
            tensor_model_parallel_size=dest_tp_pp[0],
            pipeline_model_parallel_size=dest_tp_pp[1],
        )
        if use_fpsl:
            load_strategy = get_default_load_sharded_strategy(ckpt_dir_A)
            load_strategy = FullyParallelLoadStrategyWrapper(load_strategy)
        else:
            load_strategy = None
        state_dict, missing_keys, unexpected_keys = load(
            gpt_model_B.sharded_state_dict(),
            ckpt_dir_A,
            load_strategy,
            strict=StrictHandling.RETURN_ALL,
        )
        # Potential mismatch is because of extra states which is ok
        assert all('_extra_state' in k for k in missing_keys)
        assert all('_extra_state' in k for k in unexpected_keys)
        gpt_model_B.load_state_dict(state_dict)
        save(gpt_model_B.sharded_state_dict(), ckpt_dir_B)
        # Snapshot model B. This previously read gpt_model_A, which made the regular
        # state dict comparison below compare model A with itself (a check that could never fail).
        regular_state_dict_B = gpt_model_B.state_dict()
        Utils.destroy_model_parallel()

        # Test both checkpoints are equal
        Utils.initialize_model_parallel(1, 1)
        plain_state_dict_A = load_plain_tensors(ckpt_dir_A)
        plain_state_dict_B = load_plain_tensors(ckpt_dir_B)
        diffs = diff(plain_state_dict_A, plain_state_dict_B)
        assert not any(map(bool, diffs)), diffs

        # Regular (non-sharded) state dicts only hold this rank's part of the model, and
        # their key names depend on the layer spec. They are therefore only directly comparable
        # when both models share exactly the same parallel layout, rank order and layer spec.
        # Across a real reconfiguration, the plain-tensor check above already covers equality.
        same_local_layout = (
            tuple(src_tp_pp) == tuple(dest_tp_pp)
            and (src_tp_pp_kwargs or {}) == (dst_tp_pp_kwargs or {})
            and load_order == store_order
            and src_layer_spec_fn is dst_layer_spec_fn
        )
        if same_local_layout:
            # Drop `_extra_state` entries (extra / FP8 states) before comparing
            regular_state_dict_A = {
                k: v for k, v in regular_state_dict_A.items() if not k.endswith('_extra_state')
            }
            regular_state_dict_B = {
                k: v for k, v in regular_state_dict_B.items() if not k.endswith('_extra_state')
            }
            diffs = diff(regular_state_dict_A, regular_state_dict_B)
            assert not any(map(bool, diffs)), diffs
        Utils.destroy_model_parallel()


def common_test_state_dict_comparison(initialize_model_fn, tmp_path_dist_ckpt):
    """Check that ``diff`` tells apart checkpoints of two differently-initialized models.

    Two models with the same architecture and a fixed TP=2 / PP=4 layout are created (the
    factory is called with 1 and 2 respectively) and saved to separate checkpoints. Reading
    checkpoint A twice must yield no differences. A vs B must have identical key sets and a
    mismatch for every tensor.

    Args:
        initialize_model_fn: model factory, called as ``initialize_model_fn(<1 or 2>,
            tensor_model_parallel_size=..., pipeline_model_parallel_size=...)``. The leading
            integer presumably selects the random init (confirm against the factory).
        tmp_path_dist_ckpt: base directory (supports ``/``) for the temporary checkpoints.
    """
    tp_size, pp_size = 2, 4
    Utils.initialize_model_parallel(tp_size, pp_size)
    dir_a_path = tmp_path_dist_ckpt / 'test_state_dict_comparison_A'
    dir_b_path = tmp_path_dist_ckpt / 'test_state_dict_comparison_B'
    with TempNamedDir(dir_a_path) as ckpt_a, TempNamedDir(dir_b_path) as ckpt_b:
        for model_index, ckpt_dir in ((1, ckpt_a), (2, ckpt_b)):
            model = initialize_model_fn(
                model_index,
                tensor_model_parallel_size=tp_size,
                pipeline_model_parallel_size=pp_size,
            )
            save(model.sharded_state_dict(), ckpt_dir)

        tensors_a = load_plain_tensors(ckpt_a)
        tensors_a_again = load_plain_tensors(ckpt_a)
        tensors_b = load_plain_tensors(ckpt_b)

        # The same checkpoint read twice must be identical
        repeat_diffs = diff(tensors_a, tensors_a_again)
        assert all(not d for d in repeat_diffs), repeat_diffs

        # Keys must coincide, but every single tensor must differ
        only_in_a, only_in_b, mismatching = diff(tensors_a, tensors_b)
        assert not (only_in_a or only_in_b), (only_in_a, only_in_b)
        assert len(mismatching) == len(tensors_a), (len(mismatching), len(tensors_a))
    Utils.destroy_model_parallel()


def common_test_vocab_size_padding_change(
    initialize_model_fn, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp
):
    """Test model loading with different vocab size (caused by TP padding).

    The vocab size given to each model is ``vocab_size_base`` rounded up to a multiple of
    ``128 * <current TP world size>``, so changing TP between save and load changes the padded
    vocab size. Checkpoint A is saved under ``src_tp_pp``, loaded into model B under
    ``dest_tp_pp`` and re-saved as checkpoint B. Vocab-dependent tensors must agree on their
    first ``vocab_size_base`` entries along dim 0; every other tensor must match exactly.

    Args:
        initialize_model_fn: model factory, called as ``initialize_model_fn(<1 or 2>,
            tensor_model_parallel_size=..., pipeline_model_parallel_size=..., vocab_size=...)``.
        tmp_path_dist_ckpt: base directory (supports ``/``) for the temporary checkpoints.
        vocab_size_base: unpadded vocab size.
        src_tp_pp: positional args for ``Utils.initialize_model_parallel`` when saving;
            ``[0]`` is the TP size and ``[1]`` the PP size.
        dest_tp_pp: same as ``src_tp_pp`` but for loading.
    """

    def padded_vocab_size(make_divisible_by=128):
        # Evaluated at call time so it uses whichever TP world size is currently initialized.
        multiple = make_divisible_by * parallel_state.get_tensor_model_parallel_world_size()
        return int(math.ceil(vocab_size_base / multiple)) * multiple

    def build_model(model_index, tp_pp):
        # Instantiate the model for the parallel layout that is currently initialized.
        return initialize_model_fn(
            model_index,
            tensor_model_parallel_size=tp_pp[0],
            pipeline_model_parallel_size=tp_pp[1],
            vocab_size=padded_vocab_size(),
        )

    # Tensors compared only on their first `vocab_size_base` entries along dim 0
    padded_keys = {
        'output_layer.weight',
        'output_layer.bias',
        'embedding.word_embeddings.weight',
    }

    with TempNamedDir(
        tmp_path_dist_ckpt / 'test_vocab_size_padding_change_A'
    ) as dir_a, TempNamedDir(tmp_path_dist_ckpt / 'test_vocab_size_padding_change_B') as dir_b:
        # Checkpoint A: saved under the source layout
        Utils.initialize_model_parallel(*src_tp_pp)
        model_a = build_model(1, src_tp_pp)
        save(model_a.sharded_state_dict(), dir_a)
        Utils.destroy_model_parallel()

        # Checkpoint B: A loaded under the destination layout, then re-saved
        Utils.initialize_model_parallel(*dest_tp_pp)
        model_b = build_model(2, dest_tp_pp)
        model_b.load_state_dict(load(model_b.sharded_state_dict(), dir_a))
        save(model_b.sharded_state_dict(), dir_b)
        Utils.destroy_model_parallel()

        # Compare both checkpoints under a TP=1 / PP=1 layout
        Utils.initialize_model_parallel(1, 1)
        plain_a = load_plain_tensors(dir_a)
        plain_b = load_plain_tensors(dir_b)
        for key in padded_keys:
            if key not in plain_a:
                continue
            head_a = plain_a.pop(key)[:vocab_size_base]
            head_b = plain_b.pop(key)[:vocab_size_base]
            assert torch.all(head_a == head_b), key

        # Everything that is not vocab-size dependent must match exactly
        remaining = diff(plain_a, plain_b)
        assert not any(map(bool, remaining)), remaining
        Utils.destroy_model_parallel()