test_checkpointing.py 22.8 KB
Newer Older
1
import torch
2
3
import torch.distributed as dist

4
import deepspeed
5
6
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
7

8
9
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
10

11
12
13
from deepspeed.runtime.pipe.topology import *
PipeTopo = PipeDataParallelTopology

14
15
16
17
import argparse
import pytest
import json
import os
Jeff Rasley's avatar
Jeff Rasley committed
18
import numbers
19
from common import distributed_test
20
from simple_model import *
21
22


23
24
25
26
27
28
29
30
31
def compare_deepspeed_states(saved_model, loaded_model):
    # These are compared in more depth in other places
    assert hasattr(loaded_model, 'module')

    assert saved_model.csr_tensor_module_names == loaded_model.csr_tensor_module_names
    assert saved_model.skipped_steps == loaded_model.skipped_steps
    assert saved_model.global_steps == loaded_model.global_steps


32
def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
33
34
    compare_deepspeed_states(saved_model, loaded_model)

35
    for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
Jeff Rasley's avatar
Jeff Rasley committed
36
        assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"
37

38
39
40
    if not compare_optimizer:
        return

41
42
    if isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
        for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups, loaded_model.optimizer.single_partition_of_fp32_groups):
Jeff Rasley's avatar
Jeff Rasley committed
43
            assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
44

Jeff Rasley's avatar
Jeff Rasley committed
45
46
47
    elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
        for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups, loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
            for p0, p1 in zip(partition0, partition1):
Jeff Rasley's avatar
Jeff Rasley committed
48
                assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
Jeff Rasley's avatar
Jeff Rasley committed
49

50
51
    elif isinstance(saved_model.optimizer, FP16_Optimizer):
        for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
Jeff Rasley's avatar
Jeff Rasley committed
52
            assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
53
54
55
56

    elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer):
        for params0, params1 in zip(saved_model.optimizer.fp32_groups, loaded_model.optimizer.fp32_groups):
            for p0, p1 in zip(params0, params1):
Jeff Rasley's avatar
Jeff Rasley committed
57
                assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
58
59
    elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
        pass
60
    else:
61
62
        assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}'

63

64
65
66
def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
    saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer
    loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer
67

68
69
    for state0, state1 in zip(saved_optimizer.state.values(),
                              loaded_optimizer.state.values()):
70
71
72
73
74
75
76
        for s0, s1 in zip(state0.values(), state1.values()):
            if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
                assert torch.equal(s0, s1)
            else:
                assert s0 == s1


Jeff Rasley's avatar
Jeff Rasley committed
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def compare_lr_scheduler_states(saved_model, loaded_model):
    assert hasattr(saved_model, 'lr_scheduler')
    assert hasattr(loaded_model, 'lr_scheduler')

    saved_scheduler = saved_model.lr_scheduler
    loaded_scheduler = loaded_model.lr_scheduler

    assert hasattr(saved_scheduler, 'state_dict')
    assert hasattr(loaded_scheduler, 'state_dict')

    saved_sd = saved_scheduler.state_dict()
    loaded_sd = loaded_scheduler.state_dict()

    print(f"saved_sd = {saved_sd}")
    print(f"loaded_sd = {loaded_sd}")

    assert saved_sd.keys() == loaded_sd.keys()

    for state0, state1 in zip(saved_sd.values(), loaded_sd.values()):
        if isinstance(state0, numbers.Number) and isinstance(state1, numbers.Number):
            assert state0 == state1


def checkpoint_correctness_verification(args,
101
102
                                        model,
                                        hidden_dim,
Jeff Rasley's avatar
Jeff Rasley committed
103
104
                                        tmpdir,
                                        load_optimizer_states=False,
105
                                        load_lr_scheduler_states=False,
106
107
                                        fp16=True,
                                        train_batch=False):
108
    dtype = torch.half if fp16 else torch.float32
Jeff Rasley's avatar
Jeff Rasley committed
109
110
111
    ds_model, _, _, _ = deepspeed.initialize(args=args,
                                             model=model,
                                             model_parameters=model.parameters())
112
113
114
    data_loader = random_dataloader(model=ds_model,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
115
116
                                    device=ds_model.device,
                                    dtype=dtype)
117
118
119
120
121
122
123
124
125
126
127

    if train_batch:
        ds_model.set_dataloader(data_loader)
        for n, batch in enumerate(data_loader):
            loss = ds_model.train_batch()
    else:
        for n, batch in enumerate(data_loader):
            loss = ds_model(batch[0], batch[1])
            print(loss)
            ds_model.backward(loss)
            ds_model.step()
128
129
130

    trained_model = ds_model

Jeff Rasley's avatar
Jeff Rasley committed
131
    save_folder = os.path.join(tmpdir, 'saved_checkpoint')
132
133
134
135
    save_tag = '1'

    trained_model.save_checkpoint(save_folder, save_tag)

Jeff Rasley's avatar
Jeff Rasley committed
136
137
138
    loaded_model, _, _, _ = deepspeed.initialize(args=args,
                                                 model=model,
                                                 model_parameters=model.parameters())
139
140
141

    loaded_model.load_checkpoint(save_folder,
                                 save_tag,
Jeff Rasley's avatar
Jeff Rasley committed
142
143
                                 load_optimizer_states=load_optimizer_states,
                                 load_lr_scheduler_states=load_lr_scheduler_states)
144

Jeff Rasley's avatar
Jeff Rasley committed
145
    compare_model_states(trained_model, loaded_model)
146

147
    if load_optimizer_states:
148
        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)
Jeff Rasley's avatar
Jeff Rasley committed
149
150
151

    if load_lr_scheduler_states:
        compare_lr_scheduler_states(trained_model, loaded_model)
152
153
154
155
156
157
158
159
160


def test_checkpoint_unfused_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
161
                "lr": 0.00015
162
163
            }
        },
164
        "gradient_clipping": 1.0,
165
166
        "fp16": {
            "enabled": True
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
        },
        "scheduler": {
            "type": "OneCycle",
            "params": {
                "cycle_first_step_size": 1000,
                "cycle_first_stair_count": 500,
                "cycle_second_step_size": 1000,
                "cycle_second_stair_count": 500,
                "decay_step_size": 1000,
                "cycle_min_lr": 0.0001,
                "cycle_max_lr": 0.0010,
                "decay_lr_rate": 0.001,
                "cycle_min_mom": 0.85,
                "cycle_max_mom": 0.99,
                "decay_mom_rate": 0.0
            }
183
184
185
186
187
188
189
190
191
192
193
194
195
        }
    }

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_unfused_optimizer(args,
                                           model,
                                           hidden_dim,
                                           load_optimizer_states):
Jeff Rasley's avatar
Jeff Rasley committed
196
        checkpoint_correctness_verification(args,
197
198
                                            model,
                                            hidden_dim,
Jeff Rasley's avatar
Jeff Rasley committed
199
                                            tmpdir,
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_unfused_optimizer(args=args,
                                       model=model,
                                       hidden_dim=hidden_dim,
                                       load_optimizer_states=True)
    _test_checkpoint_unfused_optimizer(args=args,
                                       model=model,
                                       hidden_dim=hidden_dim,
                                       load_optimizer_states=False)


def test_checkpoint_fused_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015,
                "betas": [0.8,
                          0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        }
    }

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
Jeff Rasley's avatar
Jeff Rasley committed
238
        checkpoint_correctness_verification(args,
239
240
                                            model,
                                            hidden_dim,
Jeff Rasley's avatar
Jeff Rasley committed
241
                                            tmpdir,
242
243
244
245
246
247
248
249
250
251
252
253
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_fused_optimizer(args=args,
                                     model=model,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=True)
    _test_checkpoint_fused_optimizer(args=args,
                                     model=model,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=False)


Jeff Rasley's avatar
Jeff Rasley committed
254
255
256
257
258
259
260
261
262
263
264
265
266
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [
                             (1,
                              False,
                              'Adam'),
                             (2,
                              False,
                              'Adam'),
                             (2,
                              True,
                              'deepspeed_adam'),
                         ])
def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
267
268
269
270
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
Jeff Rasley's avatar
Jeff Rasley committed
271
            "type": adam_optimizer,
272
273
274
275
276
277
278
279
280
281
282
            "params": {
                "lr": 0.00015,
                "betas": [0.8,
                          0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        },
Jeff Rasley's avatar
Jeff Rasley committed
283
        "zero_optimization": {
Jeff Rasley's avatar
Jeff Rasley committed
284
285
286
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        }
287
288
289
290
291
292
293
294
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
Jeff Rasley's avatar
Jeff Rasley committed
295
        checkpoint_correctness_verification(args,
296
297
                                            model,
                                            hidden_dim,
Jeff Rasley's avatar
Jeff Rasley committed
298
                                            tmpdir,
299
300
301
302
303
304
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_zero_optimizer(args=args,
                                    model=model,
                                    hidden_dim=hidden_dim,
                                    load_optimizer_states=True)
Jeff Rasley's avatar
Jeff Rasley committed
305
306


Jeff Rasley's avatar
Jeff Rasley committed
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [
                             (1,
                              False,
                              "Adam"),
                             (2,
                              False,
                              "Adam"),
                             (2,
                              True,
                              'deepspeed_adam'),
                         ])
def test_checkpoint_zero_no_optimizer(tmpdir,
                                      zero_stage,
                                      use_cpu_offload,
                                      adam_optimizer):
Jeff Rasley's avatar
Jeff Rasley committed
323
324
325
326
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
Jeff Rasley's avatar
Jeff Rasley committed
327
            "type": adam_optimizer,
Jeff Rasley's avatar
Jeff Rasley committed
328
329
330
331
332
333
334
335
336
337
338
339
            "params": {
                "lr": 0.00015,
                "betas": [0.8,
                          0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
Jeff Rasley's avatar
Jeff Rasley committed
340
341
342
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        }
Jeff Rasley's avatar
Jeff Rasley committed
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_zero_no_optimizer(args,
                                           model,
                                           hidden_dim,
                                           load_optimizer_states):
        checkpoint_correctness_verification(args,
                                            model,
                                            hidden_dim,
                                            tmpdir,
                                            load_optimizer_states=load_optimizer_states)

    _test_checkpoint_zero_no_optimizer(args=args,
                                       model=model,
                                       hidden_dim=hidden_dim,
                                       load_optimizer_states=False)


Jeff Rasley's avatar
Jeff Rasley committed
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [
                             (0,
                              False,
                              'Adam'),
                             (1,
                              False,
                              'Adam'),
                             (2,
                              False,
                              'Adam'),
                             (2,
                              True,
                              'deepspeed_adam'),
                         ])
def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
Jeff Rasley's avatar
Jeff Rasley committed
382
383
384
385
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
Jeff Rasley's avatar
Jeff Rasley committed
386
            "type": adam_optimizer,
Jeff Rasley's avatar
Jeff Rasley committed
387
388
389
390
391
392
393
394
395
396
397
398
            "params": {
                "lr": 0.00015,
                "betas": [0.8,
                          0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
Jeff Rasley's avatar
Jeff Rasley committed
399
400
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
Jeff Rasley's avatar
Jeff Rasley committed
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 0.001,
                "warmup_num_steps": 1000
            }
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_lr_scheduler(args,
                                      model,
                                      hidden_dim,
                                      load_optimizer_states,
                                      load_lr_scheduler_states):
        checkpoint_correctness_verification(
            args,
            model,
            hidden_dim,
            tmpdir,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states)

    _test_checkpoint_lr_scheduler(args=args,
                                  model=model,
                                  hidden_dim=hidden_dim,
                                  load_optimizer_states=False,
                                  load_lr_scheduler_states=True)


Jeff Rasley's avatar
Jeff Rasley committed
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [
                             (0,
                              False,
                              'Adam'),
                             (1,
                              False,
                              'Adam'),
                             (2,
                              False,
                              'Adam'),
                             (2,
                              True,
                              'deepspeed_adam'),
                         ])
def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
Jeff Rasley's avatar
Jeff Rasley committed
453
454
455
456
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
Jeff Rasley's avatar
Jeff Rasley committed
457
            "type": adam_optimizer,
Jeff Rasley's avatar
Jeff Rasley committed
458
459
460
461
462
463
464
465
            "params": {
                "lr": 1e-5
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
Jeff Rasley's avatar
Jeff Rasley committed
466
467
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
Jeff Rasley's avatar
Jeff Rasley committed
468
469
470
471
472
473
474
475
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 0.001,
                "warmup_num_steps": 1000
            }
Jeff Rasley's avatar
Jeff Rasley committed
476
        },
Jeff Rasley's avatar
Jeff Rasley committed
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_no_lr_scheduler(args,
                                         model,
                                         hidden_dim,
                                         load_optimizer_states,
                                         load_lr_scheduler_states):
        checkpoint_correctness_verification(
            args,
            model,
            hidden_dim,
            tmpdir,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states)

    _test_checkpoint_no_lr_scheduler(args=args,
                                     model=model,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=False,
                                     load_lr_scheduler_states=False)
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532


def test_checkpoint_fp32_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015,
                "betas": [0.8,
                          0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": False
        }
    }

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[2])
    def _test_checkpoint_fp32_optimizer(args, model, hidden_dim):
        checkpoint_correctness_verification(args, model, hidden_dim, tmpdir, fp16=False)

    _test_checkpoint_fp32_optimizer(args=args, model=model, hidden_dim=hidden_dim)
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637


@pytest.mark.parametrize("zero_stage", [0, 1])
def test_checkpoint_pipe_engine(zero_stage, tmpdir, stages=2):
    config_dict = {
        "train_batch_size": 2,
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-5
            }
        },
        "zero_optimization": {
            "stage": zero_stage
        },
        "fp16": {
            "enabled": zero_stage > 0
        },
        "scheduler": {
            "type": "OneCycle",
            "params": {
                "cycle_first_step_size": 1000,
                "cycle_first_stair_count": 500,
                "cycle_second_step_size": 1000,
                "cycle_second_stair_count": 500,
                "decay_step_size": 1000,
                "cycle_min_lr": 0.0001,
                "cycle_max_lr": 0.0010,
                "decay_lr_rate": 0.001,
                "cycle_min_mom": 0.85,
                "cycle_max_mom": 0.99,
                "decay_mom_rate": 0.0
            }
        }
    }

    @distributed_test(world_size=4)
    def _test(save_folder, num_stages):
        args = args_from_dict(tmpdir, config_dict)
        model = LinearStackPipe(num_stages=num_stages)
        checkpoint_correctness_verification(args=args,
                                            model=model,
                                            hidden_dim=model.hidden_dim,
                                            tmpdir=save_folder,
                                            fp16=config_dict['fp16']['enabled'],
                                            load_optimizer_states=True,
                                            load_lr_scheduler_states=True,
                                            train_batch=True)

    _test(tmpdir, num_stages=stages)


@pytest.mark.parametrize("base_topo,test_topo",
                         [
                             (PipeTopo(num_pp=1,
                                       num_dp=4),
                              PipeTopo(num_pp=4,
                                       num_dp=1)),
                             (PipeTopo(num_pp=2,
                                       num_dp=2),
                              PipeTopo(num_pp=2,
                                       num_dp=2)),
                             (PipeTopo(num_pp=4,
                                       num_dp=1),
                              PipeTopo(num_pp=2,
                                       num_dp=2)),
                         ])
def test_checkpoint_pipe_module(base_topo, test_topo, tmpdir):
    @distributed_test(world_size=4)
    def _test(base_topo, test_topo, save_folder):
        base_model = LinearStackPipe(topology=base_topo)
        base_model.save_state_dict(save_folder)

        dist.barrier()

        test_model = LinearStackPipe(topology=test_topo)
        test_model.load_state_dir(save_folder)

        # Base and test can have different lengths, so make sure we map from the
        # smaller to larger model
        if len(base_model.forward_funcs) < len(test_model.forward_funcs):
            A = base_model
            B = test_model
        else:
            A = test_model
            B = base_model

        # Compare layers individually since partitions are different
        for idx, A_layer in enumerate(A.forward_funcs):
            if not hasattr(A_layer, 'parameters'):
                # Skip functionals, etc.
                continue

            # Find the corresponding layer in B
            global_idx = idx + A._local_start
            B_local_idx = global_idx - B._local_start
            B_layer = B.forward_funcs[B_local_idx]

            # Compare layer parameters
            for p0, p1 in zip(A_layer.parameters(), B_layer.parameters()):
                assert torch.allclose(p0, p1, atol=1e-07), f"Model state {p0} is not equal to {p1}"

    _test(base_topo, test_topo, save_folder=tmpdir)