# test_cached_metadata.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import pickle
from copy import deepcopy
from dataclasses import fields

import torch

from megatron.core.dist_checkpointing import ShardedTensor, load, save
from megatron.core.dist_checkpointing.dict_utils import diff
from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils


class TestCachedMetadata:
    """Verify that saving with `use_cached_ckpt_structure` enabled produces a
    checkpoint identical to a regular save: same loaded state dict and same
    `.metadata` contents (modulo per-run storage fields)."""

    def setup_method(self, method):
        pass

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    def test_cached_metadata(self, tmp_path_dist_ckpt):
        """Save once without and three times with cached checkpoint structure,
        then compare the loaded state dicts and the pickled `.metadata` files.

        Three cached-strategy iterations are needed so the 3rd save actually
        runs off the cached metadata; the 3rd checkpoint directory is kept
        (not cleaned up) for the comparison.
        """
        Utils.initialize_model_parallel(2, 4)

        def _make_sharded_state_dict():
            # Identical sharded state dicts for the cached / non-cached runs.
            return {
                'sd_keyA': ShardedTensor.from_rank_offsets(
                    'keyA', torch.ones(2, 4), replica_id=Utils.rank
                ),
                'sd_keyB': ShardedTensor.from_rank_offsets(
                    'keyB', torch.ones(3, 5, 7), replica_id=Utils.world_size - Utils.rank - 1
                ),
            }

        sharded_state_dict_non_cached = _make_sharded_state_dict()
        sharded_state_dict_cached = _make_sharded_state_dict()

        # Baseline: regular (non-cached) save/load and its raw .metadata file.
        with TempNamedDir(tmp_path_dist_ckpt / 'ckpt_dir') as ckpt_dir:
            save(sharded_state_dict_non_cached, ckpt_dir, async_sharded_save=False)
            loaded_non_cached = load(sharded_state_dict_non_cached, ckpt_dir)
            with (ckpt_dir / '.metadata').open('rb') as f:
                md_non_cached = pickle.load(f)

        save_strategy = deepcopy(get_default_save_sharded_strategy())
        save_strategy.use_cached_ckpt_structure = True
        # Run over 3 iterations with cached metadata enabled.
        # The 3rd iteration will run with cached metadata;
        # its `ckpt_dir` is maintained for comparison below.
        ckpt_dir = None
        ckpt_path = None
        for i in range(3):
            # NOTE: fixed stray '$' from the original f-string, which put a
            # literal '$' into every directory name.
            ckpt_dir = TempNamedDir(tmp_path_dist_ckpt / f'ckpt_dir_{i}_cached')
            # Enter once and reuse the returned path instead of calling
            # __enter__() repeatedly on the same context manager.
            ckpt_path = ckpt_dir.__enter__()
            save(
                sharded_state_dict_cached,
                ckpt_path,
                save_strategy,
                async_sharded_save=False,
            )
            if i < 2:
                ckpt_dir.cleanup()

        loaded_cached = load(sharded_state_dict_cached, ckpt_path)
        with (ckpt_path / '.metadata').open('rb') as f:
            md_cached = pickle.load(f)

        # Check loaded state dict
        diffs = diff(loaded_non_cached, loaded_cached)
        assert not any(
            len(x) for x in diffs
        ), 'Cached metadata doesn\'t produce the same state_dict in loading'

        # Check metadata recorded in .metadata (torch.distributed Metadata),
        # field by field, skipping fields that legitimately differ per run.
        for field in fields(md_non_cached):
            if field.name not in ['storage_data', 'storage_meta']:
                diffs = diff(getattr(md_non_cached, field.name), getattr(md_cached, field.name))
                assert not any(
                    len(x) for x in diffs
                ), f'{field.name} is different in metadata from non-cached, cached metadata impls'
        ckpt_dir.cleanup()
        Utils.destroy_model_parallel()