test_deepspeed.py 46.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import dataclasses
16
import io
17
import itertools
18
import json
19
20
import os
import unittest
21
from copy import deepcopy
22

23
import datasets
24
from parameterized import parameterized
25

Stas Bekman's avatar
Stas Bekman committed
26
from tests.trainer.test_trainer import TrainerIntegrationCommon  # noqa
27
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
28
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available, unset_hf_deepspeed_config
29
from transformers.testing_utils import (
30
    CaptureLogger,
31
    CaptureStd,
32
    CaptureStderr,
33
    LoggingLevel,
34
35
36
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
37
    mockenv_context,
38
    require_deepspeed,
39
    require_optuna,
40
41
42
43
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
)
44
from transformers.trainer_utils import get_last_checkpoint, set_seed
45
from transformers.utils import WEIGHTS_NAME, is_torch_bf16_gpu_available
46

47

48
if is_torch_available():
Stas Bekman's avatar
Stas Bekman committed
49
50
51
52
53
    from tests.trainer.test_trainer import (  # noqa
        RegressionModelConfig,
        RegressionPreTrainedModel,
        get_regression_trainer,
    )
54
55


56
# Seed the RNGs once, at import time, via transformers.trainer_utils.set_seed - presumably so that the
# regression-trainer runs in the tests below are repeatable.
set_seed(42)
57

58
59
60
# default torch.distributed port; get_master_port() lets the DS_TEST_PORT env var override it
DEFAULT_MASTER_PORT = "10999"

# hub model ids used by the tests
T5_SMALL = "t5-small"
T5_TINY = "patrickvonplaten/t5-tiny-random"
GPT2_TINY = "sshleifer/tiny-gpt2"
64
65


66
67
68
69
70
def load_json(path):
    """
    Read and parse the JSON file at `path`.

    The file is decoded as UTF-8 explicitly, matching how the DeepSpeed config files are read elsewhere
    in this module, so the result does not depend on the platform's default locale encoding.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


Stas Bekman's avatar
Stas Bekman committed
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def get_master_port(real_launcher=False):
    """
    Return the torch.distributed master port to use, as a string.

    With an emulated single-gpu launcher (i.e. neither `deepspeed` nor `python -m torch.distributed`),
    torch.dist keeps the port tied up until the process exits, so it can't be used anywhere else in this
    process. To be able to run both emulated-launcher and real-launcher tests, two distinct ports are
    used: the base port for a real launcher and base port + 1 for the emulated one.

    The base port comes from the `DS_TEST_PORT` env var, falling back to `DEFAULT_MASTER_PORT`.

    Args:
        real_launcher (`bool`, *optional*, defaults to `False`):
            whether a real launcher is going to be used, or the emulated one
    """
    base_port = os.environ.get("DS_TEST_PORT", DEFAULT_MASTER_PORT)
    if real_launcher:
        return base_port
    # emulated launcher: shift by one so it never collides with a real launcher's port
    return str(int(base_port) + 1)


93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
def require_deepspeed_aio(test_case):
    """
    Decorator that skips `test_case` unless deepspeed is installed and its async-io (nvme) op is
    reported as compatible; otherwise the test is returned unchanged.
    """
    if not is_deepspeed_available():
        return unittest.skip("test requires deepspeed")(test_case)

    # imported lazily: deepspeed is only guaranteed to be importable past the check above
    import deepspeed
    from deepspeed.ops.aio import AsyncIOBuilder

    if deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]:
        return test_case
    return unittest.skip("test requires deepspeed async-io")(test_case)


109
110
if is_deepspeed_available():
    from deepspeed.utils import logger as deepspeed_logger  # noqa
111
    from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
112
    from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled  # noqa
113

114
115
116
117
118
119
120

def get_launcher(distributed=False):
    """
    Return the `deepspeed` launcher command prefix, split into argv tokens.

    1. `--num_nodes 1` is set explicitly just in case these tests end up run on a multi-node setup,
       which they can't handle.
    2. For now at most 2 gpus are used, since some quality tests may give different results with
       more gpus because they use very little data.
    """
    if distributed:
        num_gpus = min(2, get_gpu_count())
    else:
        num_gpus = 1
    port = get_master_port(real_launcher=True)
    command = f"deepspeed --num_nodes 1 --num_gpus {num_gpus} --master_port {port}"
    return command.split()
123
124


125
126
# ZeRO stages exercised by the tests (also used as keys of the per-stage config dicts)
ZERO2 = "zero2"
ZERO3 = "zero3"

# mixed-precision modes; tests pass them to get_regression_trainer as keyword names via `kwargs[dtype] = True`
FP16 = "fp16"
BF16 = "bf16"

stages = [ZERO2, ZERO3]
# bf16 is only exercised when the current gpu supports it
dtypes = [FP16, BF16] if is_torch_bf16_gpu_available() else [FP16]


def parameterized_custom_name_func(func, param_num, param):
    """
    Name a parameterized sub-test after the test function plus *all* of its params joined by "_"
    (parameterized's default name only includes the first param).
    """
    joined_args = "_".join(map(str, param.args))
    return f"{func.__name__}_{parameterized.to_safe_name(joined_args)}"


# Cartesian product of zero stages with dtypes: every stage is tested under every available dtype,
# yielding (stage, dtype) tuples for @parameterized.expand
params = list(itertools.product(stages, dtypes))
147
148


149
150
151
152
153
154
155
156
157
158
@require_deepspeed
@require_torch_gpu
class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
    """
    Testing non-Trainer DeepSpeed integration
    """

    def setUp(self):
        super().setUp()

        # env vars emulating a single-gpu distributed launch in-process (uses the emulated-launcher port)
        self.dist_env_1_gpu = dict(
            MASTER_ADDR="localhost",
            MASTER_PORT=get_master_port(real_launcher=False),
            RANK="0",
            LOCAL_RANK="0",
            WORLD_SIZE="1",
        )

    def tearDown(self):
        super().tearDown()

        # reset the ds config global so that tests state doesn't leak
        unset_hf_deepspeed_config()

    def _capture_t5_tiny_load_log(self):
        """
        Load `T5_TINY` with `AutoModel` under the emulated 1-gpu env and return the INFO-level output
        that the `transformers.modeling_utils` logger emitted meanwhile.
        """
        with LoggingLevel(logging.INFO), mockenv_context(**self.dist_env_1_gpu):
            modeling_logger = logging.get_logger("transformers.modeling_utils")
            with CaptureLogger(modeling_logger) as cl:
                AutoModel.from_pretrained(T5_TINY)
        return cl.out

    def test_init_zero3_fp16(self):
        # from_pretrained must report ZeRO-3 when the HfDeepSpeedConfig enables stage 3, and must not
        # once zero optimization is removed from the config
        ds_config = {"train_batch_size": 1, "zero_optimization": {"stage": 3}}

        # NOTE(review): the config object is bound to a local, presumably to keep it alive while the
        # model is loaded - confirm before dropping the name
        dschf = HfDeepSpeedConfig(ds_config)
        self.assertTrue(dschf.is_zero3())
        self.assertTrue(is_deepspeed_zero3_enabled())
        self.assertIn("Detected DeepSpeed ZeRO-3", self._capture_t5_tiny_load_log())

        # now remove zero optimization
        del ds_config["zero_optimization"]
        dschf = HfDeepSpeedConfig(ds_config)
        self.assertFalse(dschf.is_zero3())
        self.assertFalse(is_deepspeed_zero3_enabled())
        self.assertNotIn("Detected DeepSpeed ZeRO-3", self._capture_t5_tiny_load_log())


206
class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
    """
    Base for the Trainer DeepSpeed tests: provides the emulated single-gpu env and the zero2/zero3
    DeepSpeed configs loaded from the `ds_config_*.json` files in this test's directory.
    """

    def setUp(self):
        super().setUp()

        # default training args, used by tests to work out the expected number of training steps
        default_args = TrainingArguments(".")
        self.n_epochs = default_args.num_train_epochs
        self.batch_size = default_args.train_batch_size

        self.dist_env_1_gpu = dict(
            MASTER_ADDR="localhost",
            MASTER_PORT=get_master_port(real_launcher=False),
            RANK="0",
            LOCAL_RANK="0",
            WORLD_SIZE="1",
        )

        self.ds_config_file = {
            stage: f"{self.test_file_dir_str}/ds_config_{stage}.json" for stage in (ZERO2, ZERO3)
        }

        # pristine configs - tests must go through self.get_config_dict(stage), which hands out a copy,
        # so that the originals are never modified
        self.ds_config_dict = {}
        for stage, config_path in self.ds_config_file.items():
            with io.open(config_path, "r", encoding="utf-8") as f:
                self.ds_config_dict[stage] = json.load(f)

        # The following setting slows things down, so don't enable it by default unless needed by a test.
        # It's in the file as a demo for users since we want everything to work out of the box even if slower.
        self.ds_config_dict[ZERO3]["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = False

    def tearDown(self):
        super().tearDown()

        # reset the ds config global so that tests state doesn't leak
        unset_hf_deepspeed_config()

    def get_config_dict(self, stage):
        """Return a deep copy of the `stage` config, so that tests can modify it freely."""
        return deepcopy(self.ds_config_dict[stage])
247

248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270

@require_deepspeed
@require_torch_gpu
class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, TrainerIntegrationCommon):
    """

    This class is for testing directly via get_regression_trainer

    It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods
    which we can re-use here.

    Important: this class' setup can only work with a single gpu because it runs within the current
    pytest worker. For multi-gpu tests use TestDeepSpeedWithLauncher.

    Note: if any of the tests of this class get run there will be at least one gpu occupied by them
    until this pytest worker exits. This is because the gpu memory allocated by the cuda-kernels
    won't be released until this pytest worker exits.

    This may appear as some run-away tests if you watch `nvidia-smi` while other tests that fork new
    processes are run. So there will be one or two "stale" processes reported in `nvidia-smi`. This
    is not a bug.
    """

271
    # --- These tests are enough to run on one of zero stages --- #
272

273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
    def test_hf_ds_config_mismatch(self):
        # Purposefully configure DeepSpeed config values to mismatch the TrainingArguments values passed
        # to the trainer below; training is expected to raise, naming every mismatching key.
        # This currently doesn't cover all keys (but it could).
        ds_config = self.get_config_dict(ZERO2)

        per_device_train_batch_size = 2
        ds_config["train_micro_batch_size_per_gpu"] = per_device_train_batch_size + 2

        ds_config["train_batch_size"] = 1000

        gradient_accumulation_steps = 2
        ds_config["gradient_accumulation_steps"] = gradient_accumulation_steps + 2

        max_grad_norm = 1.0
        ds_config["gradient_clipping"] = max_grad_norm + 0.1

        adam_beta1, adam_beta2 = 0.9, 0.99
        ds_config["optimizer"]["params"]["betas"] = [adam_beta1 - 0.1, adam_beta2 - 0.1]

        fp16 = True
        ds_config["fp16"]["enabled"] = not fp16

        # names expected to appear in the exception message
        keys = [
            "per_device_train_batch_size",
            "train_batch_size",
            "gradient_accumulation_steps",
            "max_grad_norm",
            "betas",
            "fp16",
        ]

        with mockenv_context(**self.dist_env_1_gpu):
            trainer = get_regression_trainer(
                local_rank=0,
                fp16=fp16,
                deepspeed=ds_config,
                per_device_train_batch_size=per_device_train_batch_size,
                gradient_accumulation_steps=gradient_accumulation_steps,
                max_grad_norm=max_grad_norm,
                adam_beta1=adam_beta1,
                adam_beta2=adam_beta2,
            )
            with self.assertRaises(Exception) as context:
                trainer.train()

        # assertIn reports the missing key and the haystack on failure, unlike assertTrue(key in ...)
        for key in keys:
            self.assertIn(
                key,
                str(context.exception),
                f"{key} is not in the exception message:\n{context.exception}",
            )

324
325
326
327
328
329
330
331
332
    # Test various combos
    # 1. DS scheduler + DS optimizer: this is already tested by most other tests
    # 2. HF scheduler + HF optimizer:
    # 3. DS scheduler + HF optimizer:
    # 4. HF scheduler + DS optimizer:

    def test_hf_scheduler_hf_optimizer(self):
        # combo: HF scheduler + HF optimizer - both are removed from the DS config so that the
        # HF Trainer defaults get used; training must still move `a` away from its initial value
        initial_a = 0
        with mockenv_context(**self.dist_env_1_gpu):
            config = self.get_config_dict(ZERO2)
            for section in ("optimizer", "scheduler"):
                del config[section]  # force default HF Trainer optimizer / scheduler
            config["zero_optimization"]["offload_optimizer"]["device"] = "none"
            config["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
            trainer = get_regression_trainer(a=initial_a, local_rank=0, fp16=True, deepspeed=config)
            trainer.train()
        self.assertNotEqual(trainer.model.a.item(), initial_a)

    def test_ds_scheduler_hf_optimizer(self):
        # combo: DS scheduler + HF optimizer - only the optimizer is removed from the DS config
        initial_a = 0
        with mockenv_context(**self.dist_env_1_gpu):
            config = self.get_config_dict(ZERO2)
            config.pop("optimizer")  # force default HF Trainer optimizer
            zero_opt = config["zero_optimization"]
            zero_opt["offload_optimizer"]["device"] = "none"
            config["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
            trainer = get_regression_trainer(a=initial_a, local_rank=0, fp16=True, deepspeed=config)
            trainer.train()
        trained_a = trainer.model.a.item()
        self.assertNotEqual(trained_a, initial_a)

    def test_hf_scheduler_ds_optimizer(self):
        # combo: HF scheduler + DS optimizer - only the scheduler is removed from the DS config.
        # `a` is passed to the trainer explicitly (like the sibling combo tests) so that the final check
        # compares against the value the model was actually initialized with.
        a = 0
        with mockenv_context(**self.dist_env_1_gpu):
            ds_config_zero2_dict = self.get_config_dict(ZERO2)
            del ds_config_zero2_dict["scheduler"]  # force default HF Trainer scheduler
            ds_config_zero2_dict["zero_optimization"]["offload_optimizer"]["device"] = "none"
            ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
            trainer = get_regression_trainer(a=a, local_rank=0, fp16=True, deepspeed=ds_config_zero2_dict)
            trainer.train()
        new_a = trainer.model.a.item()
        self.assertNotEqual(new_a, a)
366

367
    @require_deepspeed_aio
    def test_stage3_nvme_offload(self):
        with mockenv_context(**self.dist_env_1_gpu):
            # The "nvme" path can be any directory: this only checks that DeepSpeed can use some directory
            # as if it were NVMe, offloading both optimizer state and params to it.
            offload_target = dict(device="nvme", nvme_path=self.get_auto_remove_tmp_dir())
            config = self.get_config_dict(ZERO3)
            zero_opt = config["zero_optimization"]
            zero_opt["offload_optimizer"] = offload_target
            zero_opt["offload_param"] = offload_target
            trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=config)
            with CaptureLogger(deepspeed_logger) as cl:
                trainer.train()
            self.assertIn("DeepSpeed info", cl.out, "expected DeepSpeed logger output but got none")
381

382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
    @require_optuna
    def test_hyperparameter_search(self):
        n_trials = 3
        with mockenv_context(**self.dist_env_1_gpu):
            config = self.get_config_dict(ZERO3)

            def model_init():
                # hyperparameter_search needs model_init() to build a fresh model for every trial
                return RegressionPreTrainedModel(RegressionModelConfig(a=0, b=0, double_output=False))

            trainer = get_regression_trainer(local_rank=0, fp16=True, model_init=model_init, deepspeed=config)

            with CaptureLogger(deepspeed_logger) as cl, CaptureStd() as cs:
                trainer.hyperparameter_search(direction="maximize", n_trials=n_trials)

            self.assertIn("DeepSpeed info", cl.out, "expected DeepSpeed logger output but got none")
            # the search's own progress report goes to stderr
            self.assertIn(f"Trial {n_trials-1} finished with value", cs.err, "expected hyperparameter_search output")
            self.assertIn("Best is trial", cs.err, "expected hyperparameter_search output")

408
409
    # --- These tests need to run on both zero stages --- #

410
411
    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    def test_hf_optimizer_with_offload(self, stage, dtype):
        # non-DS optimizers can be used with ZERO-offload (as long as they have both CPU and GPU
        # implementation (except LAMB))
        config = self.get_config_dict(stage)
        config.pop("optimizer")  # no DS optimizer -> the default HF Trainer optimizer is used
        config["zero_optimization"]["offload_optimizer"]["device"] = "cpu"  # offload it to cpu
        trainer_kwargs = {"local_rank": 0, "deepspeed": config, dtype: True}

        with mockenv_context(**self.dist_env_1_gpu):
            trainer = get_regression_trainer(**trainer_kwargs)
            with CaptureLogger(deepspeed_logger) as cl:
                trainer.train()
            self.assertIn("DeepSpeed info", cl.out, "expected DeepSpeed logger output but got none")
424

425
426
    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    def test_fake_notebook_no_launcher(self, stage, dtype):
        # Emulates a notebook: there's no launcher, so the distributed env has to be set up by hand.
        #
        # unittest resets sys.stdout for every test, so `CaptureStd` would only catch the DeepSpeed log if
        # this test happened to run first in this pytest worker - afterwards `sys.stdout` is no longer the
        # same stream. So the log is captured directly from `deepspeed_logger` (the alternative would be
        # resetting `deepspeed_logger.handlers[0].setStream(sys.stdout)`).
        with mockenv_context(**self.dist_env_1_gpu):
            trainer = get_regression_trainer(local_rank=0, deepspeed=self.get_config_dict(stage), **{dtype: True})

            with CaptureLogger(deepspeed_logger) as cl:
                trainer.train()
            self.assertIn("DeepSpeed info", cl.out, "expected DeepSpeed logger output but got none")
441

442
443
    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    def test_early_get_last_lr(self, stage, dtype):
        # With DeepSpeed's fp16 and dynamic loss scaling enabled, the optimizer/scheduler steps may be
        # skipped for the first few dozen steps while the loss scale is too large, so `get_last_lr` would
        # fail if called during that warm-up stage.
        #
        # `logging_steps=1` forces an early `trainer._maybe_log_save_evaluate()`, which calls
        # `self.lr_scheduler.get_last_lr()` - originally that failed on the very first step.
        with mockenv_context(**self.dist_env_1_gpu):
            a = b = 0.0
            kwargs = dict(
                a=a,
                b=b,
                local_rank=0,
                train_len=8,
                deepspeed=self.get_config_dict(stage),
                per_device_train_batch_size=8,
                logging_steps=1,
                **{dtype: True},
            )
            trainer = get_regression_trainer(**kwargs)
            trainer.train()
            post_train_a = trainer.model.a.item()

            # XXX: for some reason the check below fails with zero3/fp16 and any/bf16 - not broken, but a
            # qualitatively different outcome, as if the optimizer did run: oddly both a and b go from 0.0 to
            # 1.0 - there is a bug somewhere that needs investigating at some point. Those combos only check
            # that train didn't fail.
            known_odd_outcome = dtype == BF16 or (stage == ZERO3 and dtype == FP16)
            if not known_odd_outcome:
                # it's enough that train didn't fail for this test, but we must check that
                # optimizer/scheduler didn't run (since if it did this test isn't testing the right thing)
                self.assertEqual(post_train_a, a)
479

480
481
    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    def test_gradient_accumulation(self, stage, dtype):
        # This test checks that we get identical weights and similar loss with:
        # 1. per_device_train_batch_size=16, gradient_accumulation_steps=1
        # 2. per_device_train_batch_size=4,  gradient_accumulation_steps=4
        # since the 2nd produces the same effective batch as the 1st, it should give the same results.
        #
        # An identical loss can be obtained for a small train_len=32, plus setting the power of the
        # initial dynamic loss scale to:
        #   "fp16.initial_scale_power": 1
        # plus having the same WarmupLR's warmup_min_lr == warmup_max_lr in the config file.
        # But for some reason, going to train_len=64 the weights start to mismatch with that setup;
        # the culprit seems to be `initial_scale_power` - leaving it at its default of 32 keeps the
        # weights identical.

        train_len = 64
        a = b = 0.0

        kwargs = dict(
            a=a,
            b=b,
            local_rank=0,
            train_len=train_len,
            deepspeed=self.get_config_dict(stage),
        )
        kwargs[dtype] = True

        def train(per_device_train_batch_size, gradient_accumulation_steps):
            # train a fresh trainer under the emulated 1-gpu env; return (training loss, final a, final b)
            with mockenv_context(**self.dist_env_1_gpu):
                trainer = get_regression_trainer(
                    **kwargs,
                    per_device_train_batch_size=per_device_train_batch_size,
                    gradient_accumulation_steps=gradient_accumulation_steps,
                )
                result = trainer.train()
                return result.training_loss, trainer.model.a.item(), trainer.model.b.item()

        no_grad_accum_loss, no_grad_accum_a, no_grad_accum_b = train(16, 1)
        # make sure the optimizer kicked in - if it hasn't changed from the original value of a then make
        # train_len bigger
        self.assertNotEqual(no_grad_accum_a, a)

        yes_grad_accum_loss, yes_grad_accum_a, yes_grad_accum_b = train(4, 4)
        self.assertNotEqual(yes_grad_accum_a, a)

        # training with a quarter of the batch size but 4 accumulation steps should give the same
        # weights, though sometimes there's still a slight difference of 1e-6
        self.assertAlmostEqual(no_grad_accum_a, yes_grad_accum_a, places=5)
        self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, places=5)

        # the loss only matches approximately at this train_len (see the note about train_len=32 at the top)
        self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=2)
538

539
    def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
        """
        Assert that `output_dir` holds a `checkpoint-{step}` dir for every step in `range(freq, total, freq)`,
        each containing the common HF files plus the DeepSpeed engine files expected for `stage`/`dtype`
        under its `global_step{step}` subdir.

        Adapted from TrainerIntegrationCommon.check_saved_checkpoints.

        Raises:
            ValueError: if `stage` is neither ZERO2 nor ZERO3.
        """
        file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]

        if stage == ZERO2:
            ds_file_list = ["mp_rank_00_model_states.pt"]
        elif stage == ZERO3:
            ds_file_list = ["zero_pp_rank_0_mp_rank_00_model_states.pt"]
        else:
            raise ValueError(f"unknown stage {stage}")

        # use the module's dtype constant rather than a literal, like the rest of the file
        if dtype == BF16:
            ds_file_list.append("bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt")

        for step in range(freq, total, freq):
            checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
            self.assertTrue(os.path.isdir(checkpoint), f"[{stage}] {checkpoint} dir is not found")

            # common files
            for filename in file_list:
                path = os.path.join(checkpoint, filename)
                self.assertTrue(os.path.isfile(path), f"[{stage}] {path} is not found")

            # ds files
            ds_path = os.path.join(checkpoint, f"global_step{step}")
            for filename in ds_file_list:
                path = os.path.join(ds_path, filename)
                self.assertTrue(os.path.isfile(path), f"[{stage}] {path} is not found")
570

571
572
    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    def test_save_checkpoints(self, stage, dtype):
        # adapted from TrainerIntegrationTest.test_save_checkpoints
        freq = 5
        output_dir = self.get_auto_remove_tmp_dir()

        config = self.get_config_dict(stage)
        if dtype == FP16:
            config["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
        if stage == ZERO3:
            # XXX: re-enable gathering of the 16bit weights on save (setUp disables it by default);
            # presumably needed so that ZeRO-3 checkpoints include the model weights file - confirm
            config["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True

        # save a checkpoint every `freq` steps
        with mockenv_context(**self.dist_env_1_gpu):
            trainer = get_regression_trainer(output_dir=output_dir, save_steps=freq, deepspeed=config, **{dtype: True})
            trainer.train()

        # NOTE(review): 64 presumably is get_regression_trainer's default train_len (none is passed above)
        n_samples = 64
        total = int(self.n_epochs * n_samples / self.batch_size)
        self.check_saved_checkpoints_deepspeed(output_dir, freq, total, stage, dtype)
597

598
599
    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    def test_can_resume_training_errors(self, stage, dtype):
        # check the errors raised when `resume_from_checkpoint` can't be satisfied
        with mockenv_context(**self.dist_env_1_gpu):
            ds_config_dict = self.get_config_dict(stage)
            output_dir = self.get_auto_remove_tmp_dir()
            kwargs = dict(output_dir=output_dir, deepspeed=ds_config_dict)
            kwargs[dtype] = True
            trainer = get_regression_trainer(**kwargs)

            # 1. fail to find any checkpoint - due to a fresh output_dir
            with self.assertRaises(Exception) as context:
                trainer.train(resume_from_checkpoint=True)
            self.assertIn(
                "No valid checkpoint found in output directory",
                str(context.exception),
                f"got exception: {context.exception}",
            )

            # 2. fail to find a bogus checkpoint
            bogus_checkpoint = f"{os.path.join(output_dir, 'checkpoint-5')}-bogus"
            with self.assertRaises(Exception) as context:
                trainer.train(resume_from_checkpoint=bogus_checkpoint)
            self.assertIn(
                "Can't find a valid checkpoint at",
                str(context.exception),
                f"got exception: {context.exception}",
            )

623
624
    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    def test_can_resume_training_normal(self, stage, dtype):
        # adapted from TrainerIntegrationTest.test_can_resume_training
        # test normal resume for each stage separately, error-handling is tested in a different test

        # a fresh, auto-removed tmp dir per sub-test: a fixed dir in the cwd would be left behind after the
        # test and be shared (and clobbered) by the different parameterized stage/dtype variants
        output_dir = self.get_auto_remove_tmp_dir()
        ds_config_dict = self.get_config_dict(stage)
        if dtype == FP16:
            ds_config_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
        if stage == ZERO3:
            # XXX: re-enable gathering of the 16bit weights on save (setUp disables it by default);
            # presumably needed so that ZeRO-3 checkpoints include the model weights - confirm
            ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True

        kwargs = dict(output_dir=output_dir, train_len=128, save_steps=5, learning_rate=0.1, deepspeed=ds_config_dict)
        kwargs[dtype] = True

        with mockenv_context(**self.dist_env_1_gpu):
            trainer = get_regression_trainer(**kwargs)
            trainer.train()
            (a, b) = trainer.model.a.item(), trainer.model.b.item()
            state = dataclasses.asdict(trainer.state)

            checkpoint = os.path.join(output_dir, "checkpoint-5")

            # Reinitialize trainer
            trainer = get_regression_trainer(**kwargs)

            trainer.train(resume_from_checkpoint=checkpoint)
            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
            state1 = dataclasses.asdict(trainer.state)
            self.assertEqual(a, a1)
            self.assertEqual(b, b1)
            self.check_trainer_state_are_the_same(state, state1)

            # Now check with a later checkpoint that it also works when we span over one epoch
            checkpoint = os.path.join(output_dir, "checkpoint-15")

            # Reinitialize trainer and load model
            trainer = get_regression_trainer(**kwargs)

            trainer.train(resume_from_checkpoint=checkpoint)
            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
            state1 = dataclasses.asdict(trainer.state)
            self.assertEqual(a, a1)
            self.assertEqual(b, b1)
            self.check_trainer_state_are_the_same(state, state1)

            # Finally, should be able to resume with the same trainer/same deepspeed engine instance
            # XXX: but currently this not possible due DS bug: https://github.com/microsoft/DeepSpeed/issues/1612
            # trainer.train(resume_from_checkpoint=checkpoint)
            # a workaround needs to be used that re-creates the deepspeed engine

674
675
    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    def test_load_state_dict_from_zero_checkpoint(self, stage, dtype):
        # test that we can load fp32 weights directly from the zero checkpoint into the current model
        output_dir = self.get_auto_remove_tmp_dir()

        kwargs = dict(
            output_dir=output_dir,
            train_len=4,
            per_device_train_batch_size=4,
            num_train_epochs=1,
            save_strategy="steps",
            save_steps=1,
            learning_rate=0.1,
            deepspeed=self.get_config_dict(stage),
            **{dtype: True},
        )

        with mockenv_context(**self.dist_env_1_gpu):
            trainer = get_regression_trainer(**kwargs)
            trainer.train()
            trained_a, trained_b = trainer.model.a.item(), trainer.model.b.item()
            state_before = dataclasses.asdict(trainer.state)

            # reload the weights from the most recent checkpoint saved by the run above
            model = load_state_dict_from_zero_checkpoint(trainer.model, get_last_checkpoint(output_dir))

            loaded_a, loaded_b = model.a.item(), model.b.item()
            state_after = dataclasses.asdict(trainer.state)
            self.assertEqual(trained_a, loaded_a)
            self.assertEqual(trained_b, loaded_b)
            self.check_trainer_state_are_the_same(state_before, state_after)

709
710
711
712
    def test_config_object(self):
        # Exercise the global DeepSpeed config bookkeeping in one process: switching between ZeRO-3
        # and ZeRO-2, `is_deepspeed_zero3_enabled()` and `deepspeed_config()`.
        common_kwargs = dict(output_dir=self.get_auto_remove_tmp_dir(), train_len=8, fp16=True)
        zero3_config = self.get_config_dict(ZERO3)
        zero2_config = self.get_config_dict(ZERO2)

        with mockenv_context(**self.dist_env_1_gpu):
            # creating a trainer with a ZeRO-3 config reports ZeRO-3 as enabled
            trainer = get_regression_trainer(deepspeed=zero3_config, **common_kwargs)
            self.assertTrue(is_deepspeed_zero3_enabled())

            # ...and so does a fresh ZeRO-3 trainer that is actually trained
            trainer = get_regression_trainer(deepspeed=zero3_config, **common_kwargs)
            trainer.train()
            self.assertTrue(is_deepspeed_zero3_enabled())

            # with a ZeRO-2 trainer, ZeRO-3 must be reported as disabled
            trainer = get_regression_trainer(deepspeed=zero2_config, **common_kwargs)
            self.assertFalse(is_deepspeed_zero3_enabled())

            # while a trainer is alive its config is reachable
            self.assertTrue(bool(deepspeed_config()), "Deepspeed config should be accessible")

            # the global is held through a weakref, so dropping the last trainer should clear it
            del trainer
            config_after_del = deepspeed_config()
            self.assertFalse(is_deepspeed_zero3_enabled())
            self.assertFalse(bool(config_after_del), "Deepspeed config should not be accessible")

    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    def test_load_best_model(self, stage, dtype):
        # Test that forced deepspeed reinit doesn't break the model. The forced re-init after
        # loading the best model in Trainer is there to work around this bug in Deepspeed
        # https://github.com/microsoft/DeepSpeed/issues/1612
        #
        # The test is derived from a repro script submitted in this Issue:
        # https://github.com/huggingface/transformers/issues/17114
        #
        # One additional feature of this test is that we use a non-AdamW optimizer to test that
        # deepspeed doesn't fall back to AdamW, which would prevent the optimizer states from loading
        # correctly
        #
        # NOTE(review): `dtype` is not forwarded to TrainingArguments, so the dtype variants of this
        # test currently run identical configurations.

        from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer  # noqa

        output_dir = self.get_auto_remove_tmp_dir()  # "./xxx", after=False, before=False)

        ds_config_dict = self.get_config_dict(stage)
        del ds_config_dict["optimizer"]  # will use HF Trainer optimizer
        del ds_config_dict["scheduler"]  # will use HF Trainer scheduler
        # must use this setting to get the reload path exercised
        ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True

        with mockenv_context(**self.dist_env_1_gpu):
            args_dict = {
                # `per_device_*` replaces the deprecated `per_gpu_*` arguments
                "per_device_train_batch_size": 1,
                "per_device_eval_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "learning_rate": 1e-4,
                "num_train_epochs": 1,
                "do_train": True,
                "do_eval": True,
                "optim": "adafactor",
                "evaluation_strategy": "steps",
                "eval_steps": 1,
                "save_strategy": "steps",
                "save_steps": 1,
                "load_best_model_at_end": True,
                "max_steps": 1,
                "deepspeed": ds_config_dict,
                # like the launcher-based tests, don't activate any installed logging integrations
                "report_to": "none",
            }

            training_args = TrainingArguments(output_dir, **args_dict)
            tokenizer = T5Tokenizer.from_pretrained(T5_TINY)
            model = T5ForConditionalGeneration.from_pretrained(T5_TINY)

            def _format_example(example):
                # Build the seq2seq text pair from a SQuAD record: the target is the first answer,
                # or an empty string when the record has no answer.
                answers = example["answers"]["text"]
                example["input_text"] = f"question: {example['question']}  context: {example['context']}"
                example["target_text"] = answers[0] if len(answers) > 0 else ""
                return example

            def _convert_to_features(example_batch):
                # `padding="max_length"` is the non-deprecated spelling of `pad_to_max_length=True`
                input_encodings = tokenizer.batch_encode_plus(
                    example_batch["input_text"], padding="max_length", max_length=512, truncation=True
                )
                target_encodings = tokenizer.batch_encode_plus(
                    example_batch["target_text"], padding="max_length", max_length=16, truncation=True
                )
                return {
                    "input_ids": input_encodings["input_ids"],
                    "attention_mask": input_encodings["attention_mask"],
                    "labels": target_encodings["input_ids"],
                }

            def get_dataset():
                # the same SQuAD sample file serves as both the train and the eval split
                data_file = str(self.tests_dir / "fixtures/tests_samples/SQUAD/sample.json")
                data_files = dict(train=data_file, validation=data_file)
                raw_datasets = datasets.load_dataset("json", data_files=data_files, field="data")
                train_dataset = raw_datasets["train"].map(_format_example).map(_convert_to_features, batched=True)
                valid_dataset = deepcopy(train_dataset)
                return train_dataset, valid_dataset

            train_dataset, eval_dataset = get_dataset()

            trainer = Trainer(
                model=model,
                tokenizer=tokenizer,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
            )
            trainer.train()  # crash 1 was here
            trainer.evaluate()  # crash 2 was here


@slow
@require_deepspeed
@require_torch_gpu
832
class TestDeepSpeedWithLauncher(TestCasePlus):
Patrick von Platen's avatar
Patrick von Platen committed
833
    """This class is for testing via an external script - can do multiple gpus"""
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849

    # Tests to devise #
    #
    # 1. predict_with_generate on multigpu - need to figure out how to give input sequences so that
    # the 2 gpus will generate prediction sequences that aren't of the same length - this is because
    # we had to code a special feature to sync the gpus when the predicted sequences aren't of the
    # same length. In general this will be tested as a side-effect through a variety of other tests -
    # it'll simply hang trying to synchronize with other gpus if this problem is encountered. So as
    # long as we have a few full tests running on zero3 + predict_with_generate this should be
    # mostly covered.
    #
    # but there are 5 variations on beam search in `generate` - with identical code branched with `if
    # synced_gpus`
    #
    # 2. most tests should probably be run on both: zero2 and zero3 configs
    #
850

851
    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    @require_torch_multi_gpu
    def test_basic_distributed(self, stage, dtype):
        # `@parameterized.expand` must be listed first: `@require_*` skip decorators placed above it
        # are not applied to the generated test cases, so they would run without enough GPUs.
        self.run_and_check(stage=stage, dtype=dtype, distributed=True)

    def test_do_eval_no_train(self):
        # Evaluation-only run. Only ZeRO-3 is covered, since ZeRO-2 makes no sense for inference.
        eval_only = dict(do_train=False, do_eval=True)
        self.run_and_check(stage=ZERO3, dtype=FP16, eval_steps=1, distributed=False, **eval_only)

    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    def test_fp32_non_distributed(self, stage, dtype):
        # A real model needs too much GPU memory under stage2 + fp32, so a tiny random model is used
        # instead - hence only basic completion is verified, not output quality.
        run_options = {
            "model_name": T5_TINY,
            "distributed": False,
            "do_train": True,
            "do_eval": True,
            "quality_checks": False,
            "fp32": True,
        }
        self.run_and_check(stage=stage, dtype=dtype, **run_options)

    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    @require_torch_multi_gpu
    def test_fp32_distributed(self, stage, dtype):
        # `@parameterized.expand` must be listed first: `@require_*` skip decorators placed above it
        # are not applied to the generated test cases, so they would run without enough GPUs.
        #
        # real model needs too much GPU memory under stage2+fp32, so using tiny random model here -
        # therefore no quality checks, just basic completion checks are done
        self.run_and_check(
            stage=stage,
            dtype=dtype,
            model_name=T5_TINY,
            distributed=True,
            do_train=True,
            do_eval=True,
            quality_checks=False,
            fp32=True,
        )

    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    def test_resume_train_not_from_ds_checkpoint(self, stage, dtype):
        # Train normally, then start a second run from the saved model directory itself
        # (passed as --model_name_or_path) rather than from the deepspeed checkpoint.
        actions = {"do_train": True, "do_eval": False}
        run_kwargs = {"stage": stage, "dtype": dtype, "eval_steps": 1, "distributed": True, **actions}

        # step 1: a regular training run; its output dir is where the model gets saved
        first_run_dir = self.run_and_check(**run_kwargs)

        # step 2: use that very directory as the model to start from
        second_run_dir = self.run_trainer(model_name=first_run_dir, **run_kwargs)

        self.do_checks(second_run_dir, **actions)

    @parameterized.expand(["bf16", "fp16", "fp32"])
    @require_torch_multi_gpu
    def test_inference(self, dtype):
        # `@parameterized.expand` must be listed first: `@require_*` skip decorators placed above it
        # are not applied to the generated test cases, so they would run without enough GPUs.
        if dtype == "bf16" and not is_torch_bf16_gpu_available():
            self.skipTest("test requires bfloat16 hardware support")

        # this is just inference, so no optimizer should be loaded
        # it only works for z3 (makes no sense with z1-z2)
        fp32 = dtype == "fp32"
        self.run_and_check(
            stage=ZERO3,
            # forward the parametrized dtype: a hard-coded FP16 made the "bf16" case run in fp16.
            # For "fp32", `fp32=True` keeps run_trainer from emitting a dtype flag at all.
            dtype=dtype,
            model_name=T5_TINY,
            distributed=True,
            do_train=False,
            do_eval=True,
            quality_checks=False,
            fp32=fp32,
        )

    def do_checks(self, output_dir, do_train=True, do_eval=True, quality_checks=True):
        # Validate the metrics files left in `output_dir`: each requested stage must have written
        # its results file containing the expected metric and, when `quality_checks` is on, that
        # metric must also exceed a minimal threshold.
        expectations = (
            (do_train, "train_results.json", "train_samples_per_second", 0.5),
            (do_eval, "eval_results.json", "eval_bleu", 1),
        )
        for enabled, results_file, metric, minimum in expectations:
            if not enabled:
                continue
            metrics = load_json(os.path.join(output_dir, results_file))
            self.assertIn(metric, metrics)
            if quality_checks:
                self.assertGreater(metrics[metric], minimum)

    # XXX: need to do better validation beyond just that the run was successful
    def run_and_check(
        self,
        stage,
        dtype,
        model_name: str = T5_SMALL,
        eval_steps: int = 10,
        distributed: bool = True,
        do_train: bool = True,
        do_eval: bool = True,
        quality_checks: bool = True,
        fp32: bool = False,
        extra_args_str: str = None,
        remove_args_str: str = None,
    ):
        # Run the translation example for a single epoch, validate the metrics it wrote and return
        # its output directory. The default model is a small real one, so quality can be checked.
        trainer_options = dict(
            stage=stage,
            dtype=dtype,
            model_name=model_name,
            eval_steps=eval_steps,
            num_train_epochs=1,
            do_train=do_train,
            do_eval=do_eval,
            distributed=distributed,
            fp32=fp32,
            extra_args_str=extra_args_str,
            remove_args_str=remove_args_str,
        )
        results_dir = self.run_trainer(**trainer_options)
        self.do_checks(results_dir, do_train=do_train, do_eval=do_eval, quality_checks=quality_checks)
        return results_dir

    def run_trainer(
        self,
        stage: str,
        dtype: str,
        model_name: str,
        eval_steps: int = 10,
        num_train_epochs: int = 1,
        do_train: bool = False,
        do_eval: bool = True,
        distributed: bool = True,
        fp32: bool = False,
        extra_args_str: str = None,
        remove_args_str: str = None,
    ):
        """Run `run_translation.py` (en -> ro) under DeepSpeed via the launcher and return its output dir.

        Args:
            stage: selects the `ds_config_{stage}.json` file next to this test.
            dtype: emitted as a `--{dtype}` flag unless `fp32` is set.
            model_name: passed as `--model_name_or_path`.
            eval_steps: passed as `--eval_steps`.
            num_train_epochs: passed as `--num_train_epochs` when training.
            do_train / do_eval: which stages to run; at least one must be enabled.
            distributed: forwarded to `get_launcher`.
            fp32: when True no dtype flag is added.
            extra_args_str: whitespace-separated extra CLI args appended to the script args.
            remove_args_str: whitespace-separated tokens dropped from the script args
                (currently only works for bool args).

        Returns:
            The (auto-removed) output directory used for the run.
        """
        max_len = 32
        data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        args = f"""
            --model_name_or_path {model_name}
            --train_file {data_dir}/train.json
            --validation_file {data_dir}/val.json
            --output_dir {output_dir}
            --overwrite_output_dir
            --max_source_length {max_len}
            --max_target_length {max_len}
            --val_max_target_length {max_len}
            --warmup_steps 8
            --predict_with_generate
            --save_steps 0
            --eval_steps {eval_steps}
            --group_by_length
            --label_smoothing_factor 0.1
            --source_lang en
            --target_lang ro
            --report_to none
        """.split()
        # The command is executed without a shell, so every list item reaches the script verbatim:
        # shell-style quotes around the prefix would become part of every model input.
        args.extend(["--source_prefix", "translate English to Romanian: "])

        if not fp32:
            args.append(f"--{dtype}")

        if do_train:
            args.extend(
                f"""
            --do_train
            --num_train_epochs {num_train_epochs}
            --max_train_samples 16
            --per_device_train_batch_size 2
            --learning_rate 3e-3
            """.split()
            )

        if do_eval:
            args.extend(
                """
            --do_eval
            --max_eval_samples 16
            --per_device_eval_batch_size 2
            """.split()
            )

        assert do_train or do_eval, "need at least do_train or do_eval for the test to run"

        if extra_args_str is not None:
            args.extend(extra_args_str.split())

        # currently only works for bool args (a removed flag's value token would be left behind)
        if remove_args_str is not None:
            remove_args = set(remove_args_str.split())
            args = [x for x in args if x not in remove_args]

        ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
        script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
        launcher = get_launcher(distributed)

        cmd = launcher + script + args + ds_args
        # keep for quick debug (shlex.join re-quotes arguments such as the source prefix for a shell)
        # print(f"\nPYTHONPATH={self.src_dir_str} " + shlex.join(cmd)); die
        execute_subprocess_async(cmd, env=self.get_env())

        return output_dir

    @parameterized.expand(params, name_func=parameterized_custom_name_func)
    def test_clm(self, stage, dtype):
        # Covers model.resize_token_embeddings(), which needs params gathered outside of forward:
        # `run_clm.py` calls it, while `run_translation.py` does not.
        fixtures_dir = self.tests_dir / "fixtures"
        output_dir = self.get_auto_remove_tmp_dir()
        script_args = f"""
            --model_name_or_path {GPT2_TINY}
            --train_file {fixtures_dir}/sample_text.txt
            --validation_file {fixtures_dir}/sample_text.txt
            --output_dir {output_dir}
            --overwrite_output_dir
            --do_train
            --do_eval
            --max_train_samples 16
            --max_eval_samples 16
            --per_device_train_batch_size 2
            --per_device_eval_batch_size 2
            --num_train_epochs 1
            --warmup_steps 8
            --block_size 64
            --report_to none
            """.split()
        script_args.append(f"--{dtype}")

        deepspeed_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
        script_path = f"{self.examples_dir_str}/pytorch/language-modeling/run_clm.py"

        cmd = [*get_launcher(distributed=True), script_path, *script_args, *deepspeed_args]
        # keep for quick debug
        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
        execute_subprocess_async(cmd, env=self.get_env())

    def test_clm_from_config_zero3_fp16(self):
        # Builds the model through AutoModel.from_config(config) (--model_type instead of a
        # checkpoint) to make sure zero.Init is called; stderr must report that ZeRO-3 was detected.
        fixtures_dir = self.tests_dir / "fixtures"
        output_dir = self.get_auto_remove_tmp_dir()
        script_args = f"""
            --model_type gpt2
            --tokenizer_name {GPT2_TINY}
            --train_file {fixtures_dir}/sample_text.txt
            --validation_file {fixtures_dir}/sample_text.txt
            --output_dir {output_dir}
            --overwrite_output_dir
            --do_train
            --max_train_samples 4
            --per_device_train_batch_size 2
            --num_train_epochs 1
            --warmup_steps 8
            --block_size 8
            --fp16
            --report_to none
            """.split()

        deepspeed_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split()
        script_path = f"{self.examples_dir_str}/pytorch/language-modeling/run_clm.py"

        cmd = [*get_launcher(distributed=True), script_path, *script_args, *deepspeed_args]
        # keep for quick debug
        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
        with CaptureStderr() as captured:
            execute_subprocess_async(cmd, env=self.get_env())
        self.assertIn("Detected DeepSpeed ZeRO-3", captured.err)