test_nonpersistent.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import filecmp
import os
from unittest import mock

import pytest

from megatron.training.arguments import parse_args
from megatron.training.checkpointing import (
    _NON_PERSISTENT_CKPT_SUBDIR,
    load_checkpoint,
    save_checkpoint,
)
from tests.unit_tests.dist_checkpointing import (
    TempNamedDir,
    init_basic_mock_args,
    init_checkpointing_mock_args,
    setup_model_and_optimizer,
)
from tests.unit_tests.test_utilities import Utils


class TestNonPersistentSaveAndLoad:
    def setup_method(self, method):
        pass

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    @pytest.mark.parametrize(('tp,pp'), [(2, 4)])
    def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, tp, pp):
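        """Interleave non-persistent and persistent checkpoint saves and verify that
        load_checkpoint always restores the most recently saved iteration."""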
        Utils.initialize_model_parallel(tp, pp)
        num_floating_point_operations_so_far = 0
        model, optimizer = setup_model_and_optimizer(1, tp, pp)
        opt_param_scheduler = None

        mock_args = parse_args(ignore_unknown_args=True)
        with TempNamedDir(
            tmp_path_dist_ckpt / "test_non_persistent"
        ) as non_persistent_ckpt_dir, mock.patch(
            'megatron.training.checkpointing.get_args', new=lambda: mock_args
        ), mock.patch(
            "megatron.training.checkpointing.update_num_microbatches"
        ):
            init_basic_mock_args(mock_args, tp, pp)
            init_checkpointing_mock_args(mock_args, non_persistent_ckpt_dir)
            mock_args.non_persistent_ckpt_type = "global"

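            # Alternate non-persistent saves (iterations 2, 4, 8) with regular persistent
            # saves (iterations 3, 6); after each save, load_checkpoint should restore
            # whichever iteration was written last.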
            save_checkpoint(
                2,
                model,
                optimizer,
                opt_param_scheduler,
                num_floating_point_operations_so_far,
                {},
                non_persistent_ckpt=True,
            )
            save_checkpoint(
                3, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {}
            )
            save_checkpoint(
                4,
                model,
                optimizer,
                opt_param_scheduler,
                num_floating_point_operations_so_far,
                {},
                non_persistent_ckpt=True,
            )
            iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler)
            assert iteration == 4
            save_checkpoint(
                6, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {}
            )
            iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler)
            assert iteration == 6
            save_checkpoint(
                8,
                model,
                optimizer,
                opt_param_scheduler,
                num_floating_point_operations_so_far,
                {},
                non_persistent_ckpt=True,
            )
            iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler)
            assert iteration == 8
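            # Persistent checkpoints live at the top level of the checkpoint directory,
            # non-persistent ones under _NON_PERSISTENT_CKPT_SUBDIR; the older
            # non-persistent checkpoint (iteration 2) should have been cleaned up.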
            assert "iter_0000003" in os.listdir(non_persistent_ckpt_dir)
            assert "iter_0000006" in os.listdir(non_persistent_ckpt_dir)
            assert "iter_0000002" not in os.listdir(
                os.path.join(non_persistent_ckpt_dir, _NON_PERSISTENT_CKPT_SUBDIR)
            )
            assert "iter_0000004" in os.listdir(
                os.path.join(non_persistent_ckpt_dir, _NON_PERSISTENT_CKPT_SUBDIR)
            )
            assert "iter_0000008" in os.listdir(
                os.path.join(non_persistent_ckpt_dir, _NON_PERSISTENT_CKPT_SUBDIR)
            )
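            # Every checkpoint captures the same (untrained) model and optimizer state,
            # so the shard files should be byte-identical across directories; common.pt
            # and .metadata are skipped since they may legitimately differ between saves.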
            ckpt_dirs = [
                "iter_0000003",
                "iter_0000006",
                _NON_PERSISTENT_CKPT_SUBDIR + "/iter_0000004",
                _NON_PERSISTENT_CKPT_SUBDIR + "/iter_0000008",
            ]
            for ckpt_a in ckpt_dirs:
                for ckpt_b in ckpt_dirs:
                    for filename in os.listdir(os.path.join(non_persistent_ckpt_dir, ckpt_a)):
                        if filename != "common.pt" and filename != ".metadata":
                            assert filecmp.cmp(
                                os.path.join(non_persistent_ckpt_dir, ckpt_a, filename),
                                os.path.join(non_persistent_ckpt_dir, ckpt_b, filename),
                                shallow=False,
                            ), [filename, ckpt_a, ckpt_b]
        Utils.destroy_model_parallel()


class TestLegacySaveAndLoad:
    @pytest.mark.parametrize(('tp,pp'), [(2, 4)])
    def test_basic_save_load_scenario(self, tmp_path_dist_ckpt, tp, pp):
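        """A single persistent save at iteration 2 should be loadable and appear as
        iter_0000002 at the top level of the checkpoint directory."""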
        Utils.initialize_model_parallel(tp, pp)
        num_floating_point_operations_so_far = 0
        model, optimizer = setup_model_and_optimizer(1, tp, pp)
        opt_param_scheduler = None

        mock_args = parse_args(ignore_unknown_args=True)
        with TempNamedDir(tmp_path_dist_ckpt / "test_legacy") as legacy_ckpt_dir, mock.patch(
            'megatron.training.checkpointing.get_args', new=lambda: mock_args
        ), mock.patch("megatron.training.checkpointing.update_num_microbatches"):
            init_basic_mock_args(mock_args, tp, pp)
            init_checkpointing_mock_args(mock_args, legacy_ckpt_dir)

            save_checkpoint(
                2, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {}
            )
            iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler)
            assert iteration == 2
            assert "iter_0000002" in os.listdir(legacy_ckpt_dir)

        Utils.destroy_model_parallel()