test_fully_parallel.py 16.5 KB
Newer Older
silencealiang's avatar
silencealiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from pathlib import Path
from typing import List, Tuple
from unittest import mock

import pytest
import torch
import torch.distributed

from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.dict_utils import (
    dict_list_map_outplace,
    map_reduce,
    nested_values,
)
from megatron.core.dist_checkpointing.exchange_utils import _get_empty_tensor_for_exchange
from megatron.core.dist_checkpointing.mapping import (
    ShardedObject,
    ShardedStateDict,
    ShardedTensorFactory,
    is_main_replica,
)
from megatron.core.dist_checkpointing.strategies.base import (
    LoadShardedStrategy,
    SaveShardedStrategy,
)
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
    FullyParallelLoadStrategyWrapper,
    FullyParallelSaveStrategyWrapper,
    _sharded_tensor_shard_id,
)
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils


class MockSaveStrategy(SaveShardedStrategy):
    def __init__(self):
        super().__init__('mock', 1)
        self.save_keys = set()

    def save(self, sharded_state_dict, ckpt_dir):
        for sh_ten in nested_values(sharded_state_dict):
            if is_main_replica(sh_ten.replica_id):
                self.save_keys.add(sh_ten.key)


class MockLoadStrategy(LoadShardedStrategy):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.load_keys = set()

    def load(self, sharded_state_dict, ckpt_dir):
        for sh_ten in nested_values(sharded_state_dict):
            if is_main_replica(sh_ten.replica_id):
                self.load_keys.add(sh_ten.key)

        def load_rand(x):
            assert isinstance(x, ShardedTensor) or isinstance(x, ShardedObject)
            if isinstance(x, ShardedTensor):
                x.init_data(self.device)
                x.data.fill_(Utils.rank)
                return x.data
            else:
                x.data = [Utils.rank]
                return x.data

        return dict_list_map_outplace(load_rand, sharded_state_dict)

    def load_tensors_metadata(self, checkpoint_dir: Path):
        pass

    def check_backend_compatibility(self, loaded_version):
        pass

    def check_version_compatibility(self, loaded_version):
        pass


class TestFullyParallelSaveAndLoad:
    def setup_method(self, method):
        pass

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    @staticmethod
    def get_sharded_state_dict():
        return {
            'sd_key_tp_repl1': ShardedTensor.from_rank_offsets(
                'key_TP_repl1',
                torch.ones(10),
                (
                    0,
                    parallel_state.get_tensor_model_parallel_rank(),
                    parallel_state.get_tensor_model_parallel_world_size(),
                ),
                replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True),
            ),
            'sd_key_tp_repl2': ShardedTensor.from_rank_offsets(
                'key_TP_repl2',
                torch.ones(10),
                (
                    0,
                    parallel_state.get_tensor_model_parallel_rank(),
                    parallel_state.get_tensor_model_parallel_world_size(),
                ),
                replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True),
            ),
            'sd_keyB': ShardedTensor.from_rank_offsets(
                'keyB', torch.ones(20), (0, Utils.rank, Utils.world_size)
            ),
            'sd_keyE_no_C': ShardedTensor.from_rank_offsets(
                'keyC', torch.ones(100), replica_id=Utils.rank
            ),
            'sd_keyX_no_D': ShardedTensor.from_rank_offsets(
                'keyD', torch.ones(1000), replica_id=Utils.rank
            ),
            'sd_keyC_no_E': ShardedTensor.from_rank_offsets(
                'keyE', torch.ones(100), replica_id=Utils.rank
            ),
        }

    @pytest.mark.parametrize("parallelization_along_dp", [False, True])
    def test_save_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt):
        Utils.initialize_model_parallel(2, 1)
        state_dict = self.get_sharded_state_dict()

        # Ranks assignment:
        # 1. Lowest coverage
        # 2. Largest tensor
        # 3. Shard id (key)
        if not parallelization_along_dp:
            expected_key_to_saving_ranks = {
                'keyB': list(
                    range(Utils.world_size)
                ),  # everyone must save (disjoint shards, coverage == 1)
                'key_TP_repl1': [0, 1],  # lowest coverage (4), first TP domain
                'key_TP_repl2': [2, 3],  # lowest coverage (4), second TP domain
                'keyD': [4],  # largest tensor
                'keyC': [5],  # second largest tensor
                'keyE': [6],  # second largest tensor
            }
        else:
            if parallel_state.get_tensor_model_parallel_rank() == 0:
                expected_key_to_saving_ranks = {
                    # everyone must save (disjoint shards, coverage == 1):
                    'keyB': list(
                        range(
                            parallel_state.get_data_parallel_world_size(with_context_parallel=True)
                        )
                    ),
                    # this time, TP sharded tensors have the same coverage as fully replicated!
                    'keyD': [0],  # largest tensor
                    'keyC': [1],  # second largest tensor
                    'keyE': [2],  # second largest tensor
                    'key_TP_repl1': [3],  # smallest tensor
                    'key_TP_repl2': [3],  # smallest tensor, last rank is the least occupied
                }
            else:
                expected_key_to_saving_ranks = {
                    # everyone must save (disjoint shards, coverage == 1):
                    'keyB': list(
                        range(
                            parallel_state.get_data_parallel_world_size(with_context_parallel=True)
                        )
                    ),
                    # tensors C, D, E are absent in this DP group
                    'key_TP_repl1': [0],  # smallest tensor
                    'key_TP_repl2': [1],  # smallest tensor, last rank is the least occupied
                }

        parallelization_group = (
            parallel_state.get_data_parallel_group(with_context_parallel=True)
            if parallelization_along_dp
            else None
        )
        dp_rank = torch.distributed.get_rank(parallelization_group)
        expected_keys_saved_by_current_rank = {
            k for k, v in expected_key_to_saving_ranks.items() if dp_rank in v
        }

        # Run save and tests
        mock_strategy = MockSaveStrategy()
        save_strategy = FullyParallelSaveStrategyWrapper(
            mock_strategy, parallelization_group, do_cache_distribution=True
        )
        with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A:
            save_strategy.save(state_dict, ckpt_dir_A)
        key_to_saving_rank = dict(
            map_reduce(
                save_strategy.cached_distribution.main_rank_for_shard.items(),
                lambda shard_rank: shard_rank[0][0],
                lambda shard_rank: shard_rank[1],
            )
        )
        assert expected_key_to_saving_ranks == key_to_saving_rank

        for _, sh_ten in state_dict.items():
            if (
                _sharded_tensor_shard_id(sh_ten)
                in save_strategy.cached_distribution.shards_in_this_group
            ):
                is_expected_to_be_saved_by_this_rank = dp_rank in expected_key_to_saving_ranks.get(
                    sh_ten.key, []
                )
                assert sh_ten.replica_id == int(
                    not is_expected_to_be_saved_by_this_rank
                ), expected_key_to_saving_ranks

        assert mock_strategy.save_keys == expected_keys_saved_by_current_rank, (
            Utils.rank,
            mock_strategy.save_keys,
            expected_keys_saved_by_current_rank,
        )

    @pytest.mark.parametrize("parallelization_along_dp", [False, True])
    def test_load_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt):
        Utils.initialize_model_parallel(2, 1)

        state_dict = self.get_sharded_state_dict()

        # Ranks assignment:
        # 1. Lowest coverage
        # 2. Largest tensor
        # 3. Shard id (key)
        if not parallelization_along_dp:
            expected_key_to_saving_ranks = {
                'keyB': list(
                    range(Utils.world_size)
                ),  # everyone must save (disjoint shards, coverage == 1)
                'key_TP_repl1': [0, 1],  # lowest coverage (4), first TP domain
                'key_TP_repl2': [2, 3],  # lowest coverage (4), second TP domain
                'keyD': [4],  # largest tensor
                'keyC': [5],  # second largest tensor
                'keyE': [6],  # second largest tensor
            }
        else:
            # When loading, expected key distribution is the same across TP, because every replica
            # needs to be loaded
            expected_key_to_saving_ranks = {
                # everyone must load (disjoint shards, coverage == 1):
                'keyB': list(
                    range(parallel_state.get_data_parallel_world_size(with_context_parallel=True))
                ),
                # this time, TP sharded tensors have the same coverage as fully replicated!
                'keyD': [0],  # largest tensor
                'keyC': [1],  # second largest tensor
                'keyE': [2],  # second largest tensor
                'key_TP_repl1': [3],  # smallest tensor
                'key_TP_repl2': [3],  # smallest tensor, last rank is the least occupied
            }

        parallelization_group = (
            parallel_state.get_data_parallel_group(with_context_parallel=True)
            if parallelization_along_dp
            else None
        )
        dp_rank = torch.distributed.get_rank(parallelization_group)
        expected_keys_saved_by_current_rank = {
            k for k, v in expected_key_to_saving_ranks.items() if dp_rank in v
        }

        # Run save and tests
        mock_strategy = MockLoadStrategy()
        load_strategy = FullyParallelLoadStrategyWrapper(
            mock_strategy, parallelization_group, do_cache_distribution=True
        )
        with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A:
            loaded_state_dict = load_strategy.load(state_dict, ckpt_dir_A)
        key_to_saving_rank = dict(
            map_reduce(
                load_strategy.cached_distribution.main_rank_for_shard.items(),
                lambda shard_rank: shard_rank[0][0],
                lambda shard_rank: shard_rank[1],
            )
        )
        assert expected_key_to_saving_ranks == key_to_saving_rank

        assert mock_strategy.load_keys == expected_keys_saved_by_current_rank, (
            Utils.rank,
            mock_strategy.load_keys,
            expected_keys_saved_by_current_rank,
        )

        assert loaded_state_dict.keys() == state_dict.keys()

    @pytest.mark.parametrize('state_dict_device', ['cpu', 'cuda'])
    @pytest.mark.flaky
    @pytest.mark.flaky_in_dev
    def test_memory_usage(self, state_dict_device, tmp_path_dist_ckpt):
        Utils.initialize_model_parallel(2, 1)

        megabytes = 1024 * 1024
        mock_strategy = MockLoadStrategy(state_dict_device)

        mem_alloc = []

        real_get_empty_tensor_for_exchange = _get_empty_tensor_for_exchange

        def mock_get_empty_tensor_for_exchange(*args, **kwargs) -> torch.Tensor:
            ret = real_get_empty_tensor_for_exchange(*args, **kwargs)
            mem_alloc.append(torch.cuda.memory_allocated())
            return ret

        load_strategy = FullyParallelLoadStrategyWrapper(mock_strategy)
        torch.distributed.barrier()

        # Each tensor is 4MB, 40MB in total.
        # We expect extra memory usage peak at ~32MB, not 1GB
        sharded_state_dict = {
            f'ten_{i}': ShardedTensor.from_rank_offsets(
                f'ten_{i}',
                torch.rand(megabytes, dtype=torch.float, device=state_dict_device),
                (0, Utils.rank, Utils.world_size),
            )
            for i in range(10)
        }

        mem_alloc_start = torch.cuda.memory_allocated()

        with mock.patch(
            'megatron.core.dist_checkpointing.exchange_utils._get_empty_tensor_for_exchange',
            new=mock_get_empty_tensor_for_exchange,
        ), TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A:
            _ = load_strategy.load(sharded_state_dict, ckpt_dir_A)

        # Each rank is expected to do 7 * 10 empty allocations
        assert len(mem_alloc) == 7 * 10
        # Peak mem usage should be within 4MB (single tensor)
        assert max(mem_alloc) - mem_alloc_start < 4.01 * megabytes, (
            max(mem_alloc),
            mem_alloc_start,
        )

        Utils.destroy_model_parallel()

    def test_only_necessary_exchanges_performed_during_load(self, tmp_path_dist_ckpt):
        Utils.initialize_model_parallel(2, 1)

        # State dict with 2 expected exchanges
        sharded_state_dict_baseline_two_exchanges = {
            'needed_by_all_A': ShardedTensor.from_rank_offsets(
                'needed_by_all_A',
                torch.ones(4, dtype=torch.float, device='cuda'),
                replica_id=Utils.rank,
            ),
            'needed_by_all_B': ShardedTensor.from_rank_offsets(
                'needed_by_all_B',
                torch.ones(4, dtype=torch.float, device='cuda'),
                replica_id=Utils.rank,
            ),
        }
        # State dict with 1 expected exchange
        sharded_state_dict_baseline_one_exchange = {
            'needed_by_all': sharded_state_dict_baseline_two_exchanges['needed_by_all_A']
        }
        # State dict with 1 expected exchanges even though there are 2 tensors to load (1 is unique for each rank)
        sharded_state_dict_test_one_exchange = sharded_state_dict_baseline_one_exchange.copy()
        sharded_state_dict_test_one_exchange['unique'] = ShardedTensor.from_rank_offsets(
            'unique',
            torch.ones(4, dtype=torch.float, device='cuda'),
            (0, Utils.rank, Utils.world_size),
        )

        expected_call_counts: List[Tuple[ShardedStateDict, int]] = [
            (sharded_state_dict_baseline_one_exchange, 1),
            (sharded_state_dict_baseline_two_exchanges, 2),
            (sharded_state_dict_test_one_exchange, 1),
        ]

        mock_strategy = MockLoadStrategy()
        with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir:
            for sharded_state_dict, expected_count in expected_call_counts:
                load_strategy = FullyParallelLoadStrategyWrapper(
                    mock_strategy, None, do_cache_distribution=True, exchange_algo='broadcast'
                )
                with mock.patch(
                    'megatron.core.dist_checkpointing.strategies.fully_parallel.torch.distributed.broadcast'
                ) as broadcast_mock:
                    _ = load_strategy.load(sharded_state_dict, ckpt_dir)
                    assert broadcast_mock.call_count == expected_count

        Utils.destroy_model_parallel()

    def test_broadcast_sharded_objects(self, tmp_path_dist_ckpt):

        sharded_state_dict = {
            f'Obj_{i}': ShardedObject(f'Obj_{i}', None, (1,), (0,), replica_id=abs(Utils.rank - i))
            for i in range(Utils.world_size)
        }

        with TempNamedDir(tmp_path_dist_ckpt / 'test_broadcast_sharded_objects') as ckpt_dir:
            load_strategy = MockLoadStrategy()
            load_strategy = FullyParallelLoadStrategyWrapper(load_strategy, None)

            loaded_state_dict = load_strategy.load(sharded_state_dict, ckpt_dir)

            # each rank is supposed to only load obj_rank because of how replica_id is set
            assert load_strategy.base_strategy.load_keys == set({f'Obj_{Utils.rank}'})

            # since each rank only loaded their Obj they were broadcasted
            assert set(sharded_state_dict.keys()) == set(loaded_state_dict.keys())