test_checkpointing.py 23.4 KB
Newer Older
1
import torch
2

3
4
import torch.distributed as dist

5
import deepspeed
6
7
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
8

9
10
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
11

12
13
14
from deepspeed.runtime.pipe.topology import *
# Short alias for the pipeline/data-parallel topology class; used to build the
# (base_topo, test_topo) parameter pairs for test_checkpoint_pipe_module.
PipeTopo = PipeDataParallelTopology

15
16
17
18
import argparse
import pytest
import json
import os
Jeff Rasley's avatar
Jeff Rasley committed
19
import numbers
20
from common import distributed_test
21
from simple_model import *
22
23


24
25
26
27
28
29
30
31
32
def compare_deepspeed_states(saved_model, loaded_model):
    """Assert that engine-level bookkeeping survived a checkpoint round trip.

    Only the loaded engine is required to expose ``module``; module contents
    are compared in more depth by the callers. The recorded
    ``csr_tensor_module_names`` and the skipped/global step counters must be
    equal between the two engines.
    """
    assert hasattr(loaded_model, 'module')

    # Checked in this order; the first mismatch fails the assertion.
    for attr_name in ('csr_tensor_module_names', 'skipped_steps', 'global_steps'):
        assert getattr(saved_model, attr_name) == getattr(loaded_model, attr_name)


33
def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
    """Assert that a reloaded engine matches the engine it was saved from.

    First checks engine bookkeeping via ``compare_deepspeed_states``, then the
    module parameters. Unless ``compare_optimizer`` is False, it also compares
    the fp32 tensors held by the optimizer wrapper. Which attribute holds them
    depends on the type of ``saved_model.optimizer``. A plain
    ``torch.optim.Optimizer`` is accepted without further checks, and any
    other optimizer type fails the test.
    """
    def _assert_all_close(tensor_pairs, label):
        # Shared tolerance for every tensor comparison in this function.
        for p0, p1 in tensor_pairs:
            assert torch.allclose(p0, p1, atol=1e-07), f"{label} {p0} is not equal to {p1}"

    compare_deepspeed_states(saved_model, loaded_model)

    _assert_all_close(zip(saved_model.module.parameters(),
                          loaded_model.module.parameters()),
                      "FP16 model state")

    if not compare_optimizer:
        return

    saved_optimizer = saved_model.optimizer

    # The isinstance checks run in a fixed order; the first match wins.
    if isinstance(saved_optimizer, FP16_DeepSpeedZeroOptimizer):
        _assert_all_close(zip(saved_optimizer.single_partition_of_fp32_groups,
                              loaded_model.optimizer.single_partition_of_fp32_groups),
                          "Fp32 model states")
        return

    if isinstance(saved_optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
        for saved_part, loaded_part in zip(saved_optimizer.local_sub_partitions_of_fp32_groups,
                                           loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
            _assert_all_close(zip(saved_part, loaded_part), "Fp32 model states")
        return

    if isinstance(saved_optimizer, FP16_Optimizer):
        _assert_all_close(zip(saved_optimizer.fp32_groups_flat,
                              loaded_model.optimizer.fp32_groups_flat),
                          "FP32 model states")
        return

    if isinstance(saved_optimizer, FP16_UnfusedOptimizer):
        for saved_group, loaded_group in zip(saved_optimizer.fp32_groups,
                                             loaded_model.optimizer.fp32_groups):
            _assert_all_close(zip(saved_group, loaded_group), "FP32 model states")
        return

    if isinstance(saved_optimizer, torch.optim.Optimizer):
        # No separate fp32 master copies to compare.
        return

    assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}'

64

65
66
67
def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
    """Assert that two engines' base optimizers hold identical per-parameter state.

    With ``fp16`` the engine's ``optimizer`` is a wrapper, and the base
    optimizer is reached through its own ``optimizer`` attribute. Per-parameter
    state dicts are paired positionally, and so are the entries inside each
    state dict. Tensor entries must be exactly equal (``torch.equal``); all
    other entries must compare ``==``. ``hidden_dim`` is unused and kept only
    for call compatibility.

    NOTE(review): pairing uses ``zip``, so a length mismatch is not detected.
    """
    def _base_optimizer(engine):
        if fp16:
            return engine.optimizer.optimizer
        return engine.optimizer

    saved_base = _base_optimizer(saved_model)
    loaded_base = _base_optimizer(loaded_model)

    for saved_state, loaded_state in zip(saved_base.state.values(),
                                         loaded_base.state.values()):
        for expected, actual in zip(saved_state.values(), loaded_state.values()):
            both_tensors = isinstance(expected, torch.Tensor) and isinstance(actual, torch.Tensor)
            if not both_tensors:
                assert expected == actual
                continue
            assert torch.equal(expected, actual)


Jeff Rasley's avatar
Jeff Rasley committed
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
def compare_lr_scheduler_states(saved_model, loaded_model):
    """Assert that two engines carry equivalent LR-scheduler state.

    Both engines must expose an ``lr_scheduler`` with a ``state_dict()``. The
    two state dicts must have the same keys. For every key whose value is
    numeric (``numbers.Number``) on both sides, the values must be equal;
    non-numeric entries are not compared.

    Raises:
        AssertionError: on any missing attribute, key mismatch, or differing
            numeric value.
    """
    assert hasattr(saved_model, 'lr_scheduler')
    assert hasattr(loaded_model, 'lr_scheduler')

    saved_scheduler = saved_model.lr_scheduler
    loaded_scheduler = loaded_model.lr_scheduler

    assert hasattr(saved_scheduler, 'state_dict')
    assert hasattr(loaded_scheduler, 'state_dict')

    saved_sd = saved_scheduler.state_dict()
    loaded_sd = loaded_scheduler.state_dict()

    print(f"saved_sd = {saved_sd}")
    print(f"loaded_sd = {loaded_sd}")

    # keys() views compare like sets, so equal key sets may still be listed in
    # different orders. Values are therefore paired by key, not by position.
    assert saved_sd.keys() == loaded_sd.keys()

    for key, state0 in saved_sd.items():
        state1 = loaded_sd[key]
        if isinstance(state0, numbers.Number) and isinstance(state1, numbers.Number):
            assert state0 == state1, \
                f"LR scheduler state '{key}' differs: {state0} != {state1}"


def checkpoint_correctness_verification(args,
                                        model,
                                        hidden_dim,
                                        tmpdir,
                                        load_optimizer_states=False,
                                        load_lr_scheduler_states=False,
                                        fp16=True,
                                        train_batch=False):
    """Train a DeepSpeed engine, checkpoint it, reload it, and compare the two.

    Args:
        args: namespace passed to ``deepspeed.initialize``.
        model: the module to wrap. May also be a ``(model_to_train,
            model_to_load)`` pair. The first is trained and saved; the second
            backs the engine that loads the checkpoint.
        hidden_dim: feature size passed to ``random_dataloader``; also
            forwarded to ``compare_optimizer_states``.
        tmpdir: directory under which ``saved_checkpoint`` is written.
        load_optimizer_states: restore and then compare optimizer state.
        load_lr_scheduler_states: restore and then compare LR-scheduler state.
        fp16: use half-precision batches and compare the wrapped optimizer.
        train_batch: drive training with ``train_batch()`` (used for pipeline
            engines) instead of forward/backward/step.
    """
    if isinstance(model, (list, tuple)):
        model_to_train, model_to_load = model
    else:
        # NOTE(review): with a single module, both engines below wrap the same
        # object. If deepspeed.initialize does not copy it (verify), the
        # parameter comparison in compare_model_states compares a module with
        # itself. Pass two separately constructed modules to avoid that.
        model_to_train = model_to_load = model

    dtype = torch.half if fp16 else torch.float32

    ds_model, _, _, _ = deepspeed.initialize(args=args,
                                             model=model_to_train,
                                             model_parameters=model_to_train.parameters())
    data_loader = random_dataloader(model=ds_model,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=ds_model.device,
                                    dtype=dtype)

    if train_batch:
        ds_model.set_dataloader(data_loader)
        # One train_batch() call per batch the loader yields; train_batch()
        # pulls its own data from the loader registered above.
        for _ in data_loader:
            ds_model.train_batch()
    else:
        for batch in data_loader:
            loss = ds_model(batch[0], batch[1])
            print(loss)
            ds_model.backward(loss)
            ds_model.step()

    trained_model = ds_model

    save_folder = os.path.join(tmpdir, 'saved_checkpoint')
    save_tag = '1'

    trained_model.save_checkpoint(save_folder, save_tag)

    loaded_model, _, _, _ = deepspeed.initialize(args=args,
                                                 model=model_to_load,
                                                 model_parameters=model_to_load.parameters())

    loaded_model.load_checkpoint(save_folder,
                                 save_tag,
                                 load_optimizer_states=load_optimizer_states,
                                 load_lr_scheduler_states=load_lr_scheduler_states)

    compare_model_states(trained_model, loaded_model)

    if load_optimizer_states:
        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)

    if load_lr_scheduler_states:
        compare_lr_scheduler_states(trained_model, loaded_model)
153
154


155
156
@pytest.mark.skipif(not deepspeed.ops.__installed_ops__['lamb'],
                    reason="lamb is not installed")
def test_checkpoint_unfused_optimizer(tmpdir):
    """Checkpoint round trip for an fp16 Lamb engine with a OneCycle scheduler.

    Gradient clipping is 1.0. The round trip runs twice on 2 ranks: first
    restoring optimizer state, then without it.
    """
    one_cycle_params = {
        "cycle_first_step_size": 1000,
        "cycle_first_stair_count": 500,
        "cycle_second_step_size": 1000,
        "cycle_second_stair_count": 500,
        "decay_step_size": 1000,
        "cycle_min_lr": 0.0001,
        "cycle_max_lr": 0.0010,
        "decay_lr_rate": 0.001,
        "cycle_min_mom": 0.85,
        "cycle_max_mom": 0.99,
        "decay_mom_rate": 0.0
    }
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True
        },
        "scheduler": {
            "type": "OneCycle",
            "params": one_cycle_params
        }
    }

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_unfused_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    # Restore optimizer state on the first pass, skip it on the second.
    for restore_optimizer in (True, False):
        _test_checkpoint_unfused_optimizer(args=args,
                                           model=model,
                                           hidden_dim=hidden_dim,
                                           load_optimizer_states=restore_optimizer)


def test_checkpoint_fused_optimizer(tmpdir):
    """Checkpoint round trip for an fp16 Adam engine without a scheduler.

    Runs on 2 ranks, first restoring optimizer state and then without it.
    """
    adam_params = {
        "lr": 0.00015,
        "betas": [0.8, 0.999],
        "eps": 1e-8,
        "weight_decay": 3e-7
    }
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {"type": "Adam", "params": adam_params},
        "fp16": {"enabled": True}
    }

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    for restore_optimizer in (True, False):
        _test_checkpoint_fused_optimizer(args=args,
                                         model=model,
                                         hidden_dim=hidden_dim,
                                         load_optimizer_states=restore_optimizer)


Jeff Rasley's avatar
Jeff Rasley committed
257
258
259
260
261
262
263
264
265
266
267
268
269
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [(1, False, 'Adam'),
                          (2, False, 'Adam'),
                          (2, True, 'deepspeed_adam')])
def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
    """Checkpoint round trip under ZeRO, restoring optimizer state.

    The CPU-offload case is skipped when the cpu-adam op is not installed.
    """
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")

    adam_params = {
        "lr": 0.00015,
        "betas": [0.8, 0.999],
        "eps": 1e-8,
        "weight_decay": 3e-7
    }
    zero_config = {"stage": zero_stage, "cpu_offload": use_cpu_offload}
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {"type": adam_optimizer, "params": adam_params},
        "fp16": {"enabled": True},
        "zero_optimization": zero_config
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    run_kwargs = dict(args=args, model=model, hidden_dim=hidden_dim)
    _test_checkpoint_zero_optimizer(load_optimizer_states=True, **run_kwargs)
Jeff Rasley's avatar
Jeff Rasley committed
311
312


Jeff Rasley's avatar
Jeff Rasley committed
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [(1, False, "Adam"),
                          (2, False, "Adam"),
                          (2, True, 'deepspeed_adam')])
def test_checkpoint_zero_no_optimizer(tmpdir,
                                      zero_stage,
                                      use_cpu_offload,
                                      adam_optimizer):
    """Checkpoint round trip under ZeRO without restoring optimizer state.

    The CPU-offload case is skipped when the cpu-adam op is not installed.
    """
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")

    adam_params = {
        "lr": 0.00015,
        "betas": [0.8, 0.999],
        "eps": 1e-8,
        "weight_decay": 3e-7
    }
    zero_config = {"stage": zero_stage, "cpu_offload": use_cpu_offload}
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {"type": adam_optimizer, "params": adam_params},
        "fp16": {"enabled": True},
        "zero_optimization": zero_config
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_zero_no_optimizer(args, model, hidden_dim, load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    run_kwargs = dict(args=args, model=model, hidden_dim=hidden_dim)
    _test_checkpoint_zero_no_optimizer(load_optimizer_states=False, **run_kwargs)


Jeff Rasley's avatar
Jeff Rasley committed
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [(0, False, 'Adam'),
                          (1, False, 'Adam'),
                          (2, False, 'Adam'),
                          (2, True, 'deepspeed_adam')])
def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
    """Checkpoint round trip that restores WarmupLR state but not optimizer state.

    The CPU-offload case is skipped when the cpu-adam op is not installed.
    """
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")

    adam_params = {
        "lr": 0.00015,
        "betas": [0.8, 0.999],
        "eps": 1e-8,
        "weight_decay": 3e-7
    }
    warmup_scheduler = {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 0.001,
            "warmup_num_steps": 1000
        }
    }
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {"type": adam_optimizer, "params": adam_params},
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": zero_stage, "cpu_offload": use_cpu_offload},
        "scheduler": warmup_scheduler
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_lr_scheduler(args,
                                      model,
                                      hidden_dim,
                                      load_optimizer_states,
                                      load_lr_scheduler_states):
        checkpoint_correctness_verification(
            args,
            model,
            hidden_dim,
            tmpdir,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states)

    restore_flags = dict(load_optimizer_states=False, load_lr_scheduler_states=True)
    _test_checkpoint_lr_scheduler(args=args, model=model, hidden_dim=hidden_dim, **restore_flags)


Jeff Rasley's avatar
Jeff Rasley committed
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [(0, False, 'Adam'),
                          (1, False, 'Adam'),
                          (2, False, 'Adam'),
                          (2, True, 'deepspeed_adam')])
def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
    """Checkpoint round trip with a WarmupLR scheduler, restoring neither
    optimizer nor scheduler state.

    The CPU-offload case is skipped when the cpu-adam op is not installed.
    """
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")

    warmup_scheduler = {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 0.001,
            "warmup_num_steps": 1000
        }
    }
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {"type": adam_optimizer, "params": {"lr": 1e-5}},
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": zero_stage, "cpu_offload": use_cpu_offload},
        "scheduler": warmup_scheduler
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_no_lr_scheduler(args,
                                         model,
                                         hidden_dim,
                                         load_optimizer_states,
                                         load_lr_scheduler_states):
        checkpoint_correctness_verification(
            args,
            model,
            hidden_dim,
            tmpdir,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states)

    restore_flags = dict(load_optimizer_states=False, load_lr_scheduler_states=False)
    _test_checkpoint_no_lr_scheduler(args=args, model=model, hidden_dim=hidden_dim, **restore_flags)
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547


def test_checkpoint_fp32_optimizer(tmpdir):
    """Checkpoint round trip for an Adam engine with fp16 disabled (2 ranks)."""
    adam_params = {
        "lr": 0.00015,
        "betas": [0.8, 0.999],
        "eps": 1e-8,
        "weight_decay": 3e-7
    }
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {"type": "Adam", "params": adam_params},
        "fp16": {"enabled": False}
    }

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_fp32_optimizer(args, model, hidden_dim):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            fp16=False)

    _test_checkpoint_fp32_optimizer(args=args,
                                    model=model,
                                    hidden_dim=hidden_dim)
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652


@pytest.mark.parametrize("zero_stage", [0, 1])
def test_checkpoint_pipe_engine(zero_stage, tmpdir, stages=2):
    """Checkpoint round trip for a LinearStackPipe engine trained via ``train_batch``.

    fp16 is enabled only when ``zero_stage > 0``. Optimizer and OneCycle
    scheduler state are both restored and compared, on 4 ranks.
    """
    use_fp16 = zero_stage > 0
    one_cycle_params = {
        "cycle_first_step_size": 1000,
        "cycle_first_stair_count": 500,
        "cycle_second_step_size": 1000,
        "cycle_second_stair_count": 500,
        "decay_step_size": 1000,
        "cycle_min_lr": 0.0001,
        "cycle_max_lr": 0.0010,
        "decay_lr_rate": 0.001,
        "cycle_min_mom": 0.85,
        "cycle_max_mom": 0.99,
        "decay_mom_rate": 0.0
    }
    config_dict = {
        "train_batch_size": 2,
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 1,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-5}},
        "zero_optimization": {"stage": zero_stage},
        "fp16": {"enabled": use_fp16},
        "scheduler": {"type": "OneCycle", "params": one_cycle_params}
    }

    @distributed_test(world_size=4)
    def _test(save_folder, num_stages):
        args = args_from_dict(tmpdir, config_dict)
        model = LinearStackPipe(num_stages=num_stages)
        checkpoint_correctness_verification(args=args,
                                            model=model,
                                            hidden_dim=model.hidden_dim,
                                            tmpdir=save_folder,
                                            fp16=use_fp16,
                                            load_optimizer_states=True,
                                            load_lr_scheduler_states=True,
                                            train_batch=True)

    _test(tmpdir, num_stages=stages)


@pytest.mark.parametrize("base_topo,test_topo",
                         [
                             (PipeTopo(num_pp=1, num_dp=4), PipeTopo(num_pp=4, num_dp=1)),
                             (PipeTopo(num_pp=2, num_dp=2), PipeTopo(num_pp=2, num_dp=2)),
                             (PipeTopo(num_pp=4, num_dp=1), PipeTopo(num_pp=2, num_dp=2)),
                         ])
def test_checkpoint_pipe_module(base_topo, test_topo, tmpdir):
    """Save a LinearStackPipe under one topology and reload it under another.

    On each of 4 ranks, each parameterized layer held by the smaller local
    partition is matched to the same global layer in the larger partition, and
    their parameters are compared.
    """
    @distributed_test(world_size=4)
    def _test(base_topo, test_topo, save_folder):
        base_model = LinearStackPipe(topology=base_topo)
        base_model.save_state_dict(save_folder)

        # Wait until every rank has saved before any rank loads.
        dist.barrier()

        test_model = LinearStackPipe(topology=test_topo)
        test_model.load_state_dir(save_folder)

        # Base and test can hold different numbers of local layers, so map
        # from the smaller partition into the larger one.
        if len(base_model.forward_funcs) < len(test_model.forward_funcs):
            smaller, larger = base_model, test_model
        else:
            smaller, larger = test_model, base_model

        # Compare layers individually since partitions are different.
        for local_idx, small_layer in enumerate(smaller.forward_funcs):
            if not hasattr(small_layer, 'parameters'):
                # Skip functionals, etc.
                continue

            # Translate to the corresponding local index in the larger partition.
            global_idx = local_idx + smaller._local_start
            large_local_idx = global_idx - larger._local_start
            # A negative index would silently wrap to the end of the list and
            # compare the wrong layer, so fail with a clear message instead.
            assert 0 <= large_local_idx < len(larger.forward_funcs), \
                f"Layer {global_idx} is not held locally by the larger partition"
            large_layer = larger.forward_funcs[large_local_idx]

            for p0, p1 in zip(small_layer.parameters(), large_layer.parameters()):
                assert torch.allclose(p0, p1, atol=1e-07), f"Model state {p0} is not equal to {p1}"

    _test(base_topo, test_topo, save_folder=tmpdir)