test_checkpoints.py 10.6 KB
Newer Older
hepj987's avatar
hepj987 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import sys
from pathlib import Path

import pytest
from parameterized import parameterized

from megatron.testing_utils import (
    CaptureStdout,
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    require_deepspeed,
    require_torch_gpu,
    require_torch_multi_gpu,
    set_seed,
)

# Seed RNGs once at import time so every test in this module starts from the same state.
# NOTE(review): which generators set_seed covers (random / numpy / torch / ...) is defined in
# megatron.testing_utils, not here - confirm there. Training and conversion run in separate
# launcher subprocesses (execute_subprocess_async), so this seed presumably affects only the
# test process itself.
set_seed(42)


def parameterized_custom_name_func(func, param_num, param):
    """Name a parameterized sub-test after *all* of its positional arguments.

    parameterized's default naming only shows the first argument; here every positional
    argument is stringified and joined with "_to_", then sanitized with
    ``parameterized.to_safe_name``, e.g. ``test_x_1_1_1_to_2_1_1``.

    ``param_num`` is intentionally ignored, so two identical argument tuples yield the
    same sub-test name.
    """
    joined_args = "_to_".join(map(str, param.args))
    safe_suffix = parameterized.to_safe_name(joined_args)
    return "{}_{}".format(func.__name__, safe_suffix)

# (source, target) topology pairs for test_checkpoint_reshaping_main.
# Each topology is written "TP_PP_DP" (tensor-, pipeline-, data-parallel size). The source
# topology is trained and saved; the target topology resumes from the converted universal
# checkpoint.
# Keep entries unique: parameterized_custom_name_func builds the sub-test name only from the
# two strings, so a duplicated pair gets the same test name and adds no coverage.
params = [
    # TP_PP_DP

    # any topology -> 1_1_1
    ["1_1_1", "1_1_1"],
    ["2_1_1", "1_1_1"],
    ["1_2_1", "1_1_1"],
    ["1_1_2", "1_1_1"],

    # 1_1_1 -> other topologies
    ["1_1_1", "2_1_1"],
    ["1_1_1", "1_2_1"],
    ["1_1_1", "1_1_2"],

    ["1_1_2", "1_1_2"],
    ["1_1_2", "2_1_1"],
    ["1_1_2", "1_2_1"],

    ["1_2_1", "1_2_1"],
    ["1_2_1", "2_1_1"],
    ["1_2_1", "1_1_2"],

    ["2_1_1", "2_1_1"],
    ["2_1_1", "1_2_1"],
    ["2_1_1", "1_1_2"],

    ["2_2_2", "1_1_1"],
    ["2_2_2", "2_2_2"],
    ["1_1_1", "2_2_2"],

    ["1_1_8", "2_2_2"],
]

def get_launcher(num_gpus):
    """Return the ``deepspeed`` launcher prefix as an argv list for ``num_gpus`` GPUs.

    ``--num_nodes 1`` is pinned explicitly in case these tests end up running on a
    multi-node setup: they are not able to handle that, so keep everything on one node.
    """
    return [
        "deepspeed",
        "--num_nodes", "1",
        "--num_gpus", str(num_gpus),
    ]

@require_deepspeed
@require_torch_gpu
class MegDSTestCheckpoints(TestCasePlus):
    """End-to-end checkpoint tests for Megatron-DeepSpeed GPT pretraining.

    A scenario trains a tiny GPT model under a source TP/PP/DP topology, converts the saved
    DeepSpeed checkpoint to a universal checkpoint with
    ``tools/convert_checkpoint/ds_to_universal.py``, and resumes training from it under a
    target topology. Training and conversion run in subprocesses; the assertions inspect
    their captured stdout.
    """

    def setUp(self):
        super().setUp()

        # at times megatron fails to build kernels and doesn't remove the lock file, which makes
        # subsequent runs hang - so make sure there is no lock when starting the testing
        meg_lock_file_path = self.repo_root_dir_str + "/megatron/fused_kernels/build/lock"
        try:
            os.unlink(meg_lock_file_path)
        except FileNotFoundError:
            pass

    def get_config(self, output_dir, tp_size, pp_size, dp_size):
        """Build the command-line arguments for a tiny GPT training run.

        Args:
            output_dir: directory under which ``checkpoints/`` and ``tensorboard/`` are written;
                ``checkpoints/`` is used both for ``--save`` and ``--load``.
            tp_size: tensor-model-parallel size.
            pp_size: pipeline-model-parallel size.
            dp_size: data-parallel size. It is not passed as a flag; it only enters via
                ``num_gpus`` (presumably Megatron derives DP = world_size / (tp * pp)).

        Returns:
            ``(args, ds_args, num_gpus)``: the Megatron argv list, the DeepSpeed argv list, and
            the number of processes/GPUs (``tp * pp * dp``) the launcher must start.
        """
        data_dir = f"{self.data_dir}/gpt2"

        num_gpus = pp_size * tp_size * dp_size
        print(f"Using {num_gpus} GPUs")

        # 300 samples / --global-batch-size 16 is roughly 18 iterations; the run exits much
        # earlier via --exit-interval
        n_samples = 300

        # XXX: for now, while testing shapes, keep runs really short and fast. The intended
        # values are exit_interval=20 (some samples in the first half and then some more in the
        # 2nd half after resume) and seq_len=128.
        # test_checkpoint_reshaping_main converts global_step1, which relies on exit_interval=1.
        exit_interval = 1
        seq_len = 8

        # common/shared configs

        ds_args = f"""
                --deepspeed
                --deepspeed_config {self.test_file_dir_str}/ds_config_bf16.json
                --zero-stage 0
                --deepspeed-activation-checkpointing
        """.split()

        # --max-position-embeddings follows seq_len so the positions cover the full sequence
        # if seq_len is raised back up
        args = f"""
                --tensor-model-parallel-size {tp_size}
                --pipeline-model-parallel-size {pp_size}
                --distributed-backend nccl

                --log-interval 1
                --save-interval 1
                --eval-interval 10
                --eval-iters 1
                --checkpoint-activations
                --partition-activations
                --exit-interval {exit_interval}

                --merge-file {data_dir}/gpt2-tiny-merges.txt
                --vocab-file {data_dir}/gpt2-tiny-vocab.json
                --save {output_dir}/checkpoints
                --load {output_dir}/checkpoints
                --data-path {data_dir}/meg-gpt2-openwebtext_text_document
                --tensorboard-dir {output_dir}/tensorboard
                --tensorboard-queue-size 5
                --log-timers-to-tensorboard
                --log-batch-size-to-tensorboard
                --log-validation-ppl-to-tensorboard

                --num-layers 2
                --hidden-size 8
                --num-attention-heads 2
                --seq-length {seq_len}
                --max-position-embeddings {seq_len}
                --micro-batch-size 1
                --global-batch-size 16
                --train-samples {n_samples}

                --embed-layernorm
                --position-embedding-type alibi

                --optimizer adam
                --adam-beta1 0.9
                --adam-beta2 0.95
                --adam-eps 1e-8
                --lr 1e-4
                --lr-warmup-samples 5
                --lr-decay-samples 6
                --clip-grad 1.0
                --weight-decay 1e-1
                --bf16

                --log-level debug
                --log-level-replica info
        """.split()

        # NOTE(review): an earlier note said reshaping failed to handle --embed-layernorm with:
        #   RuntimeError: Error(s) in loading state_dict for VocabParallelEmbedding:
        #     size mismatch for norm.weight: copying a param with shape torch.Size([128]) from
        #     checkpoint, the shape in current model is torch.Size([64]).
        #     size mismatch for norm.bias: (same shapes)
        # The flag is passed above, so this may be stale - confirm before relying on it.

        return args, ds_args, num_gpus

    def _build_cmd(self, output_dir, tp_size, pp_size, dp_size, extra_args=()):
        """Return the full launcher + pretrain_gpt.py argv for the given topology."""
        script = [f"{self.src_dir}/pretrain_gpt.py"]
        args, ds_args, num_gpus = self.get_config(output_dir, tp_size, pp_size, dp_size)
        return get_launcher(num_gpus) + script + args + ds_args + list(extra_args)

    def _run_and_capture(self, cmd):
        """Run ``cmd`` as a subprocess with the test environment and return its captured stdout."""
        # keep for quick debug
        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] + cmd)); die
        with CaptureStdout() as cs:
            execute_subprocess_async(cmd, env=self.get_env())
        return cs.out

    def _assert_resumed_and_saved(self, out, output_dir):
        """Assert that a run loaded a checkpoint from output_dir, trained, and saved again."""
        # test checkpoint loading
        self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", out)

        # test reports
        self.assertIn("consumed samples", out)

        # test checkpoint saving
        self.assertIn("successfully saved checkpoint at iteration", out)

    def train_checkpoint(self, output_dir, tp_size=1, pp_size=1, dp_size=1):
        """Train from scratch with the given topology and check that a checkpoint was saved."""
        # 1. test training from scratch (no checkpoint)
        out = self._run_and_capture(self._build_cmd(output_dir, tp_size, pp_size, dp_size))

        # test deepspeed is running
        self.assertIn("DeepSpeed info", out)

        # test reports
        self.assertIn("consumed samples", out)

        # test there should be no checkpoint this round
        self.assertIn(f"Unable to find latest file at {output_dir}/checkpoints/latest", out)

        # test checkpoint saving
        self.assertIn("successfully saved checkpoint at iteration", out)

    def convert_checkpoint_to_universal(self, output_dir, step):
        """Convert ``output_dir/checkpoints/global_step{step}`` into a universal checkpoint.

        The result is written to ``output_dir/checkpoints/global_step{step}_universal``.

        Raises:
            RuntimeError: presumably raised by ``execute_subprocess_async`` when the converter
                exits with an error (``test_checkpoint_reshaping_empty_dir`` relies on this).
        """
        # Use the interpreter running the tests and an absolute script path, so the call does
        # not depend on which `python` is on PATH or on the current working directory; an
        # explicit argv list also keeps paths containing spaces intact.
        cmd = [
            sys.executable,
            f"{self.repo_root_dir_str}/tools/convert_checkpoint/ds_to_universal.py",
            "--input_folder", f"{output_dir}/checkpoints/global_step{step}",
            "--output_folder", f"{output_dir}/checkpoints/global_step{step}_universal",
        ]
        out = self._run_and_capture(cmd)

        self.assertIn("Convert DeepSpeed Checkpoint to Universal Checkpoint", out)

    def resume_from_checkpoint(self, output_dir, tp_size=1, pp_size=1, dp_size=1):
        """Resume training from the regular DeepSpeed checkpoint in ``output_dir``."""
        cmd = self._build_cmd(output_dir, tp_size, pp_size, dp_size)
        self._assert_resumed_and_saved(self._run_and_capture(cmd), output_dir)

    def resume_from_universal_checkpoint(self, output_dir, tp_size=1, pp_size=1, dp_size=1):
        """Resume training with ``--universal-checkpoint`` from the checkpoints in ``output_dir``."""
        cmd = self._build_cmd(output_dir, tp_size, pp_size, dp_size, extra_args=["--universal-checkpoint"])
        self._assert_resumed_and_saved(self._run_and_capture(cmd), output_dir)

    # @parameterized.expand must be the outermost decorator: it generates the sub-tests from the
    # function it receives, so a decorator listed above it would not reach the generated tests
    # (see the parameterized docs on combining decorators).
    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    @require_torch_multi_gpu
    def test_checkpoint_reshaping_main(self, src, tgt):
        """Train with topology ``src``, convert to universal, then resume with topology ``tgt``.

        ``src``/``tgt`` are "TP_PP_DP" strings from ``params``. The test needs at least 2 GPUs;
        it is skipped when either topology needs more GPUs than are available.
        """
        tp_size_src, pp_size_src, dp_size_src = map(int, src.split("_"))
        tp_size_tgt, pp_size_tgt, dp_size_tgt = map(int, tgt.split("_"))

        n_gpus = get_gpu_count()
        n_gpus_src = tp_size_src * pp_size_src * dp_size_src
        n_gpus_tgt = tp_size_tgt * pp_size_tgt * dp_size_tgt

        if n_gpus_src > n_gpus:
            pytest.skip(f"the test requires {n_gpus_src} gpus for source topology but have only {n_gpus}")
        if n_gpus_tgt > n_gpus:
            pytest.skip(f"the test requires {n_gpus_tgt} gpus for target topology but have only {n_gpus}")

        # a fresh, auto-removed directory per sub-test; pass ("./xxx", after=False) instead to
        # keep the artifacts around while debugging
        output_dir = self.get_auto_remove_tmp_dir()

        # 1. train with the source topology (--save-interval 1 / --exit-interval 1 is why the
        #    conversion below uses step=1)
        self.train_checkpoint(output_dir, tp_size=tp_size_src, pp_size=pp_size_src, dp_size=dp_size_src)

        # 2. convert the saved checkpoint to a universal checkpoint
        self.convert_checkpoint_to_universal(output_dir=output_dir, step=1)

        # 3. check we can resume training from the universal checkpoint with the target topology
        self.resume_from_universal_checkpoint(output_dir, tp_size=tp_size_tgt, pp_size=pp_size_tgt, dp_size=dp_size_tgt)

    @require_torch_multi_gpu
    def test_checkpoint_reshaping_empty_dir(self):
        """Converting from a fresh directory with no checkpoint must raise RuntimeError."""
        # pass ("./xxx", after=False) instead to keep the artifacts around while debugging
        output_dir = self.get_auto_remove_tmp_dir()
        with self.assertRaises(RuntimeError):
            self.convert_checkpoint_to_universal(output_dir=output_dir, step=1)