test_flattened_resharding.py 10.8 KB
Newer Older
xingjinliang's avatar
xingjinliang committed
1
2
3
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import io
silencealiang's avatar
add  
silencealiang committed
4
from contextlib import nullcontext
xingjinliang's avatar
xingjinliang committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

import numpy as np
import pytest
import torch
from torch.distributed.checkpoint import CheckpointException

from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor, load, save
from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config
from megatron.core.dist_checkpointing.dict_utils import diff
from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensorFactory
from megatron.core.dist_checkpointing.serialization import load_tensors_metadata
from megatron.core.dist_checkpointing.strategies.resharding import (
    apply_nd_flattened_tensors_reformulation,
    restore_nd_flattened_tensors_formulation,
)
from megatron.core.dist_checkpointing.strategies.torch import get_reformulation_metadata
silencealiang's avatar
add  
silencealiang committed
22
23
24
25
from megatron.core.dist_checkpointing.validation import (
    determine_global_metadata,
    validate_sharding_integrity,
)
xingjinliang's avatar
xingjinliang committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils


class TestFlattenedResharding:
    def setup_method(self, method):
        """Per-test setup hook; no preparation is required before each test."""

    def teardown_method(self, method):
        """Per-test teardown hook: destroy any model-parallel state a test left behind."""
        Utils.destroy_model_parallel()

    @pytest.mark.parametrize(
        ('src_tp_pp', 'dest_tp_pp'),
        [((2, 4), (2, 4)), ((2, 4), (2, 2)), ((2, 4), (4, 2)), ((8, 1), (1, 2))],
    )
    def test_partition_change_save_load(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp):
        """Round-trip a checkpoint saved with one TP x PP layout and loaded with another.

        The state dict is saved under ``src_tp_pp`` and loaded under ``dest_tp_pp`` into
        randomly initialized targets. The loaded values must equal the deterministic
        (``arange``-based) values produced by ``_build_state_dict()``.
        """
        Utils.initialize_model_parallel(*src_tp_pp)
        ckpt_path = tmp_path_dist_ckpt / 'test_flattened_partition_change_save_load'
        with TempNamedDir(ckpt_path) as ckpt_dir:
            save(self._build_state_dict(), ckpt_dir)

            # Re-initialize with the destination layout, then load into random tensors.
            Utils.destroy_model_parallel()
            Utils.initialize_model_parallel(*dest_tp_pp)
            loaded = load(self._build_state_dict(random=True), ckpt_dir)

            reference = self._build_state_dict()
            expected = {key: sh_ten.data for key, sh_ten in reference.items()}
            diffs = diff(expected, loaded)
            assert not any(diffs), diffs

        Utils.destroy_model_parallel()

    @pytest.mark.parametrize(
        ('src_tp_pp', 'dest_tp_pp', 'expected_ckpt_offsets_by_rank'),
        [
            (
                (2, 4),
                (2, 2),
                {
                    0: [(0, 0, 0), (0, 0, 10)],  # TP 0, DP 0, PP 0
                    1: [(4, 0, 0), (4, 0, 10)],  # TP 1, DP 0, PP 0
                    2: [(0, 0, 0), (0, 0, 10)],  # TP 0, DP 1, PP 0
                    3: [(4, 0, 0), (4, 0, 10)],  # TP 1, DP 1, PP 0
                    4: [(0, 0, 20), (0, 0, 30)],  # TP 0, DP 0, PP 1
                    5: [(4, 0, 20), (4, 0, 30)],  # TP 1, DP 0, PP 1
                    6: [(0, 0, 20), (0, 0, 30)],  # TP 0, DP 1, PP 1
                    7: [(4, 0, 20), (4, 0, 30)],  # TP 1, DP 1, PP 1
                },
            ),
            ((8, 1), (1, 2), {rank: [(tp, 0, 0) for tp in range(8)] for rank in range(8)}),
        ],
    )
    def test_reformulate_nd_flattened_tensors(
        self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, expected_ckpt_offsets_by_rank
    ):
        """Check the N-D flattened-tensor reformulation after a TP x PP layout change.

        The state dict is saved under ``src_tp_pp`` and reformulated for ``dest_tp_pp``.
        The regular ('unflat') entry must stay a ShardedTensor. The flattened ('flat')
        entry must become a dict keyed by ``(global_offset, ckpt_local_shape)`` for exactly
        the offsets listed in ``expected_ckpt_offsets_by_rank[rank]``. The reformulated
        dict must also load and restore to the original values.
        """
        Utils.initialize_model_parallel(*src_tp_pp, order='tp-dp-pp')
        ckpt_path = tmp_path_dist_ckpt / 'test_reformulate_nd_flattened_tensors'
        with TempNamedDir(ckpt_path) as ckpt_dir:
            saved = self._build_state_dict()
            ckpt_local_shape = saved['sd_key_flat'].local_shape
            save(saved, ckpt_dir)

            # Switch to the destination TP x PP layout before loading.
            Utils.destroy_model_parallel()
            Utils.initialize_model_parallel(*dest_tp_pp, order='tp-dp-pp')
            to_load = self._build_state_dict(random=True)

            metadata = get_reformulation_metadata(to_load, ckpt_dir)
            reformulated, restore_data = apply_nd_flattened_tensors_reformulation(to_load, metadata)
            assert isinstance(reformulated['sd_key_unflat'], ShardedTensor)
            assert isinstance(reformulated['sd_key_flat'], dict)

            expected_offsets = expected_ckpt_offsets_by_rank[Utils.rank]
            expected_keys = {(offset, ckpt_local_shape) for offset in expected_offsets}
            actual_keys = reformulated['sd_key_flat'].keys()
            assert actual_keys == expected_keys, (actual_keys, ckpt_local_shape, expected_offsets)

            # The reformulated dict can be loaded through the regular high-level API.
            loaded = load(reformulated, ckpt_dir, validate_access_integrity=False)
            loaded = restore_nd_flattened_tensors_formulation(loaded, restore_data)
            expected = {key: sh_ten.data for key, sh_ten in self._build_state_dict().items()}
            diffs = diff(expected, loaded)
            assert not any(diffs), diffs

        Utils.destroy_model_parallel()

    @pytest.mark.parametrize(('src_tp_pp',), [((2, 4),), ((8, 1),), ((1, 1),), ((1, 4),)])
    def test_load_tensor_metadata(self, tmp_path_dist_ckpt, src_tp_pp):
        """Check that tensor metadata read back from a checkpoint describes global tensors.

        The checkpoint is saved under ``src_tp_pp`` and inspected under TP=1, PP=1. The
        flattened ('flat') and regular ('unflat') tensors must report identical
        ``local_shape`` and ``global_shape``. Loading the metadata itself must yield the
        full ``arange(8 * 5 * 40)`` values, reshaped to (8, 5, 40) for 'unflat' and left
        flat for 'flat'.
        """
        Utils.initialize_model_parallel(*src_tp_pp, order='tp-dp-pp')
        # Each test uses a checkpoint directory named after itself.
        with TempNamedDir(tmp_path_dist_ckpt / 'test_load_tensor_metadata') as ckpt_dir:

            state_dict = self._build_state_dict()

            save(state_dict, ckpt_dir)

            # Inspect the checkpoint from a TP=1, PP=1 layout.
            Utils.destroy_model_parallel()
            Utils.initialize_model_parallel(1, 1)

            sharded_metadata = load_tensors_metadata(ckpt_dir)

            for attr_name in ('local_shape', 'global_shape'):
                flat_val = getattr(sharded_metadata['flat'], attr_name)
                unflat_val = getattr(sharded_metadata['unflat'], attr_name)
                assert flat_val == unflat_val, (attr_name, flat_val, unflat_val)

            # Every rank loads the whole tensor. NOTE(review): distinct replica ids
            # presumably keep only rank 0 as the main replica so load-time validation
            # accepts the overlapping requests — confirm.
            for sh_ten in sharded_metadata.values():
                sh_ten.replica_id = Utils.rank
            loaded_state_dict = load(sharded_metadata, ckpt_dir)
            assert torch.all(
                loaded_state_dict['unflat'] == torch.arange(8 * 5 * 40).reshape(8, 5, 40)
            )
            assert torch.all(loaded_state_dict['flat'] == torch.arange(8 * 5 * 40))

        Utils.destroy_model_parallel()

    def _build_state_dict(self, random=False):
        """Build a sharded state dict holding one (8, 5, 40) tensor in two formulations.

        The global tensor is split along dim 0 across TP ranks and along dim 2 across
        PP ranks. ``'sd_key_unflat'`` wraps the local shard as a regular ShardedTensor
        with ``replica_id=dp_rank``. ``'sd_key_flat'`` wraps a contiguous slice of the
        flattened local shard. DP ranks own consecutive slices, each shifted by the DP
        rank, so the slices differ in length but together cover the whole shard.

        Args:
            random: if True, fill the tensor with ``torch.rand`` instead of
                ``torch.arange``, so load targets differ from the saved data.

        Returns:
            dict with ShardedTensors under ``'sd_key_unflat'`` and ``'sd_key_flat'``.
        """
        tp_rank = parallel_state.get_tensor_model_parallel_rank()
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
        pp_size = parallel_state.get_pipeline_model_parallel_world_size()
        dp_rank = parallel_state.get_data_parallel_rank()
        dp_size = parallel_state.get_data_parallel_world_size()

        fill = torch.rand if random else torch.arange
        global_ten = fill(8 * 5 * 40).reshape(8, 5, 40)
        # TP splits dim 0, PP splits dim 2.
        local_ten = global_ten.chunk(tp_size, dim=0)[tp_rank].chunk(pp_size, dim=2)[pp_rank]
        assert local_ten.shape == (8 // tp_size, 5, 40 // pp_size)

        shard_numel = local_ten.numel()
        assert shard_numel % dp_size == 0, (shard_numel, dp_size)
        per_dp = shard_numel // dp_size

        # Shift each DP slice start by dp_rank and extend every slice except the last
        # by dp_rank + 1, so the DP slices are not all equal.
        start = per_dp * dp_rank + dp_rank
        stop = per_dp * (dp_rank + 1) + (dp_rank + 1 if dp_rank + 1 < dp_size else 0)
        local_dp_slice = slice(start, stop)
        local_flat_ten = local_ten.flatten()[local_dp_slice]
        expected_numel = per_dp - dp_rank if dp_rank == dp_size - 1 else per_dp + 1
        assert local_flat_ten.numel() == expected_numel

        return {
            'sd_key_unflat': ShardedTensor.from_rank_offsets(
                'unflat',
                local_ten,
                (0, tp_rank, tp_size),
                (2, pp_rank, pp_size),
                replica_id=dp_rank,
            ),
            'sd_key_flat': ShardedTensor.from_rank_offsets_flat(
                'flat',
                local_flat_ten,
                local_ten.shape,
                (0, tp_rank, tp_size),
                (2, pp_rank, pp_size),
                flattened_range=local_dp_slice,
            ),
        }
silencealiang's avatar
add  
silencealiang committed
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268

    def test_flattened_tensors_are_properly_validated(self, tmp_path_dist_ckpt):
        """Check that sharding validation rejects incomplete flattened-tensor layouts.

        With 8 ranks, a global (6, 6) tensor is split along axis 0 into six (1, 6)
        shards. Ranks 0, 1 and 2 share shard 0 and hold its flattened ranges
        [0:1], [1:3] and [3:6]. Ranks 3..7 each hold one whole shard (1..5).
        The complete layout is validated first. Then two broken variants are
        validated, and rank 0 is expected to raise for each.
        """
        Utils.initialize_model_parallel()
        rank = Utils.rank
        world_size = Utils.world_size
        raises_here = rank == 0

        # Ranks 0, 1, 2 hold pieces of length 1, 2, 3; every other rank holds 6 elements.
        local_flat_ten = torch.ones(rank + 1 if rank <= 2 else 6) * rank

        global_flattened_len = 6 + (world_size - 3) * 6
        if world_size == 8:
            assert global_flattened_len == 1 + 2 + 3 + 5 * 6
            local_ten_shape = (1, 6)
            rank_offsets = ((0, max(0, rank - 2), 6),)
        else:
            local_ten_shape = (global_flattened_len,)
            rank_offsets = ()

        local_dp_slice_start = {0: 0, 1: 1, 2: 3}.get(rank, 0)
        local_dp_slice = slice(local_dp_slice_start, local_dp_slice_start + len(local_flat_ten))

        state_dict = {
            'sd_key_flat': ShardedTensor.from_rank_offsets_flat(
                'flat',
                local_flat_ten,
                local_ten_shape,
                *rank_offsets,
                flattened_range=local_dp_slice,
                replica_id=0,
            )
        }

        def validate(sharded_state_dict):
            validate_sharding_integrity(determine_global_metadata(sharded_state_dict)[1])

        def expect_failure(sharded_state_dict, message):
            # Only rank 0 is expected to raise; on other ranks any exception propagates.
            ctx = pytest.raises(CheckpointingException) if raises_here else nullcontext()
            with ctx as exc_info:
                validate(sharded_state_dict)
            if raises_here:
                assert message in str(exc_info.value)

        # The complete layout passes validation.
        validate(state_dict)

        # Rank 1 withholds its range [1:3] of shard 0, so shard 0 is not fully covered.
        expect_failure(
            {} if rank == 1 else state_dict,
            'Flattened ranges dont cover the whole shard ShardedTensor',
        )

        # Rank 1's range is back, but rank 4 withholds its whole shard.
        expect_failure({} if rank == 4 else state_dict, 'Invalid access pattern')

        Utils.destroy_model_parallel()