test_serialization.py 36.4 KB
Newer Older
silencealiang's avatar
silencealiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import io
import logging
import os

import numpy as np
import pytest
import torch
from torch.distributed.checkpoint import CheckpointException as PyTCheckpointingException
from torch.distributed.checkpoint import FileSystemReader

try:
    from torch.distributed import DeviceMesh
    from torch.distributed._tensor import DTensor

    HAVE_DTENSOR = True
except ImportError:
    HAVE_DTENSOR = False

from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor, load, remove_sharded_tensors, save
from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config
from megatron.core.dist_checkpointing.dict_utils import diff
from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensorFactory
from megatron.core.dist_checkpointing.serialization import (
    load_sharded_metadata,
    load_tensors_metadata,
)
from megatron.core.dist_checkpointing.strategies.base import StrategyAction, get_default_strategy
from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.utils import is_torch_min_version
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils


class TestSerialization:
    """Save/load round-trip tests for ``megatron.core.dist_checkpointing``.

    Each test initializes its own model-parallel layout through
    ``Utils.initialize_model_parallel``; ``teardown_method`` destroys it after
    every test, so explicit ``destroy_model_parallel`` calls at the end of a
    test are redundant but harmless.
    """

    def setup_method(self, method):
        pass

    def teardown_method(self, method):
        # pytest calls this after every test method in the class.
        Utils.destroy_model_parallel()

    def test_single_process_save_load(self, tmp_path_dist_ckpt):
        """Save replicated tensors and load a subset of them back by key.

        With the ``zarr`` backend, also checks that the on-disk layout has one
        directory per ShardedTensor key (``keyA``), not per state-dict key
        (``sd_keyA``).
        """
        Utils.initialize_model_parallel(1, 1)

        sharded_state_dict = {
            'sd_keyA': ShardedTensor.from_rank_offsets(
                'keyA', torch.ones(2, 4), replica_id=Utils.rank
            ),
            'sd_keyB': ShardedTensor.from_rank_offsets(
                'keyB', torch.ones(3, 5, 7), replica_id=Utils.rank
            ),
        }

        if HAVE_DTENSOR:
            mesh = DeviceMesh.from_group(
                parallel_state.get_data_parallel_group(with_context_parallel=True), "cuda"
            )
            sharded_state_dict['sd_keyD'] = ShardedTensor.from_rank_offsets(
                'keyD',
                DTensor.from_local(torch.ones(3, 5, 7), mesh)._local_tensor,
                replica_id=Utils.rank,
            )

        # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
        with TempNamedDir(
            tmp_path_dist_ckpt / 'test_single_process_save_load', sync=True
        ) as ckpt_dir:
            save(sharded_state_dict, ckpt_dir)
            torch.distributed.barrier()

            saved_config = maybe_load_config(ckpt_dir)
            if saved_config.sharded_backend == 'zarr':
                assert (ckpt_dir / 'keyA').is_dir()
                assert (ckpt_dir / 'keyB').is_dir()
                assert not (ckpt_dir / 'keyC').exists()
                assert not (ckpt_dir / 'sd_keyA').is_dir()

                if HAVE_DTENSOR:
                    assert (ckpt_dir / 'keyD').is_dir()

            load_ssd = {
                'load_sd_keyA': ShardedTensor.from_rank_offsets(
                    'keyA', torch.ones(2, 4), replica_id=Utils.rank
                )
            }
            loaded_state_dict = load(load_ssd, ckpt_dir)

            assert set(loaded_state_dict.keys()) == {'load_sd_keyA'}
            assert isinstance(loaded_state_dict['load_sd_keyA'], torch.Tensor)
            assert loaded_state_dict['load_sd_keyA'].shape == (2, 4)

        Utils.destroy_model_parallel()

    def test_multi_process_save(self, tmp_path_dist_ckpt):
        """Save tensors sharded across all ranks with access-integrity validation.

        ``preprocess_fn`` drops the per-rank ``'rank'`` entry before the
        common-state-dict consistency check, so that entry is not compared
        across ranks.
        """
        Utils.initialize_model_parallel(2, 4)

        state_dict = {
            'sd_keyA': ShardedTensor.from_rank_offsets(
                'keyA', torch.ones(2, 4), (0, Utils.rank, Utils.world_size)
            ),
            'sd_keyB': ShardedTensor.from_rank_offsets(
                'keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size)
            ),
            'lr': 0.01,
            'rank': torch.distributed.get_rank(),
        }

        def preprocess_fn(x):
            del x['rank']
            return x

        # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
        with TempNamedDir(tmp_path_dist_ckpt / 'test_multi_process_save', sync=True) as ckpt_dir:
            save(
                state_dict,
                ckpt_dir,
                validate_access_integrity=True,
                preprocess_common_before_consistancy_check=preprocess_fn,
            )

            saved_config = maybe_load_config(ckpt_dir)
            if saved_config.sharded_backend == 'zarr':
                assert (ckpt_dir / 'keyA').is_dir()
                assert (ckpt_dir / 'keyB').is_dir()
                assert not (ckpt_dir / 'keyC').exists()
                assert not (ckpt_dir / 'sd_keyA').is_dir()

        Utils.destroy_model_parallel()

    def test_multi_process_save_log_difference(self, tmp_path_dist_ckpt, caplog):
        """Check that differing common state dicts across ranks are logged.

        ``preprocess_fn`` is an identity here, so the per-rank ``'rank'``
        entry differs between ranks. Rank 0 is expected to log a warning that
        lists the difference for every other rank.
        """
        Utils.initialize_model_parallel(2, 4)

        state_dict = {
            'sd_keyA': ShardedTensor.from_rank_offsets(
                'keyA', torch.ones(2, 4), (0, Utils.rank, Utils.world_size)
            ),
            'sd_keyB': ShardedTensor.from_rank_offsets(
                'keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size)
            ),
            'rank': torch.distributed.get_rank(),
        }

        def preprocess_fn(x):
            return x

        with caplog.at_level(logging.WARNING):
            # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
            # Use a directory name unique to this test (like every other test in this class)
            # so it never collides with `test_multi_process_save`'s checkpoint directory.
            with TempNamedDir(
                tmp_path_dist_ckpt / 'test_multi_process_save_log_difference', sync=True
            ) as ckpt_dir:
                save(
                    state_dict,
                    ckpt_dir,
                    validate_access_integrity=True,
                    preprocess_common_before_consistancy_check=preprocess_fn,
                )
            # pylint: disable=line-too-long
            if torch.distributed.get_rank() == 0:
                assert (
                    "There is difference in the common state dict in different ranks. The differences are {1: ([], [], [(('rank',), <class 'int'>, <class 'int'>)]), 2: ([], [], [(('rank',), <class 'int'>, <class 'int'>)]), 3: ([], [], [(('rank',), <class 'int'>, <class 'int'>)]), 4: ([], [], [(('rank',), <class 'int'>, <class 'int'>)]), 5: ([], [], [(('rank',), <class 'int'>, <class 'int'>)]), 6: ([], [], [(('rank',), <class 'int'>, <class 'int'>)]), 7: ([], [], [(('rank',), <class 'int'>, <class 'int'>)])}"
                    in caplog.text
                )

        Utils.destroy_model_parallel()

    def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None):
        """Save under TP=2 x PP=4, then reload under different shardings.

        First every rank loads each tensor unsharded (full global shape). Then
        parallel state is re-initialized as TP=1 x PP=2 and each rank loads a
        different shard.

        Args:
            tmp_path_dist_ckpt: pytest fixture giving the base checkpoint directory.
            strategy: optional save strategy forwarded to ``save`` (default
                ``None``, i.e. whatever ``save`` picks by default).
        """
        Utils.initialize_model_parallel(2, 4)

        # ten_a: global shape (2, 4):
        ten_a_global = torch.tensor([[0, 1, 2, 3], [10, 11, 12, 13]])
        ten_a = (
            torch.zeros(1, 1)
            + 10 * parallel_state.get_tensor_model_parallel_rank()
            + parallel_state.get_pipeline_model_parallel_rank()
        )
        assert ten_a.shape == (1, 1)

        # ten_b: global shape (4, 5, 80), where (x, y, z) is (100x + z)
        ten_b = torch.zeros(4, 5, 10) + (torch.arange(10) + 10 * Utils.rank)
        ten_b += torch.arange(4).unsqueeze(-1).unsqueeze(-1) * 100
        assert ten_b.shape == (4, 5, 10)

        state_dict = {
            'sd_keyA': ShardedTensor.from_rank_offsets(
                'keyA',
                ten_a,
                (
                    0,
                    parallel_state.get_tensor_model_parallel_rank(),
                    parallel_state.get_tensor_model_parallel_world_size(),
                ),
                (
                    1,
                    parallel_state.get_pipeline_model_parallel_rank(),
                    parallel_state.get_pipeline_model_parallel_world_size(),
                ),
                replica_id=0,
            ),
            'sd_keyB': ShardedTensor.from_rank_offsets(
                'keyB', ten_b, (2, Utils.rank, Utils.world_size)
            ),
        }

        ten_a_global_shape = ten_a_global.shape
        ten_b_global_shape = (4, 5, 10 * 8)

        assert state_dict['sd_keyA'].local_shape == (1, 1)
        assert state_dict['sd_keyA'].global_shape == ten_a_global_shape
        assert state_dict['sd_keyB'].global_shape == ten_b_global_shape

        # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
        with TempNamedDir(
            tmp_path_dist_ckpt / 'test_partition_change_save_load', sync=True
        ) as ckpt_dir:
            save(state_dict, ckpt_dir, strategy)

            del ten_a, ten_b

            # without changing TPxPP, load tensors without any sharding
            load_sd = {
                'sd_keyA': ShardedTensor.from_rank_offsets(
                    'keyA', torch.empty(ten_a_global_shape), replica_id=Utils.rank
                ),
                'sd_keyB': ShardedTensor.from_rank_offsets(
                    'keyB', torch.empty(ten_b_global_shape), replica_id=Utils.rank
                ),
            }
            loaded_state_dict = load(load_sd, ckpt_dir)

            ten_a = loaded_state_dict['sd_keyA']
            ten_b = loaded_state_dict['sd_keyB']
            assert isinstance(ten_a, torch.Tensor)
            assert ten_a.shape == ten_a_global_shape
            assert torch.all(ten_a == ten_a_global)

            assert isinstance(ten_b, torch.Tensor)
            assert ten_b.shape == ten_b_global_shape
            assert np.all(
                [
                    val == 100 * x + z
                    for x, x_row in enumerate(ten_b)
                    for y, y_row in enumerate(x_row)
                    for z, val in enumerate(y_row)
                ]
            )

            del ten_a, ten_b

            # change TPxPP
            Utils.destroy_model_parallel()
            Utils.initialize_model_parallel(1, 2)

            load_sd = {
                'sd_keyA': ShardedTensor.from_rank_offsets(
                    'keyA',
                    torch.empty(2, 1),
                    (
                        1,
                        parallel_state.get_data_parallel_rank(),
                        parallel_state.get_data_parallel_world_size(),
                    ),
                    replica_id=parallel_state.get_pipeline_model_parallel_rank(),
                ),
                'sd_keyB': ShardedTensor.from_rank_offsets(
                    'keyB',
                    torch.empty(5, 80),
                    (0, Utils.rank // 2, 4),
                    prepend_axis_num=1,
                    replica_id=Utils.rank % 2,
                ),
            }

            loaded_state_dict = load(load_sd, ckpt_dir)
            ten_a = loaded_state_dict['sd_keyA']
            ten_b = loaded_state_dict['sd_keyB']

            assert isinstance(ten_a, torch.Tensor)
            assert ten_a.shape == (2, 1)
            assert torch.all(
                ten_a[:, 0] == ten_a_global[:, parallel_state.get_data_parallel_rank()]
            )

            assert isinstance(ten_b, torch.Tensor)
            assert ten_b.shape == (5, 10 * 8)
            assert torch.all(
                ten_b == torch.arange(80).unsqueeze(0).expand(5, 80) + Utils.rank // 2 * 100
            )

    def test_load_tensors_metadata(self, tmp_path_dist_ckpt):
        """Check ``load_tensors_metadata`` and loading from the returned metadata.

        The returned dict is keyed by ShardedTensor key and describes each
        tensor as a single unsharded piece (global offset 0, replica 0). That
        dict can be passed straight to ``load`` to read the full tensors.
        """
        Utils.initialize_model_parallel(2, 4)

        state_dict = {
            'sd_keyA': ShardedTensor.from_rank_offsets(
                'keyA', torch.arange(10) + Utils.rank * 10, (0, Utils.rank, Utils.world_size)
            ),
            'sd_keyB': ShardedTensor.from_rank_offsets(
                'keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size)
            ),
        }

        # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
        with TempNamedDir(tmp_path_dist_ckpt / 'test_load_tensors_metadata', sync=True) as ckpt_dir:
            save(state_dict, ckpt_dir)

            del state_dict
            sharded_state_dict = load_tensors_metadata(ckpt_dir)
            # loaded dict keys are ShardedTensor keys!
            assert 'keyA' in sharded_state_dict
            assert 'sd_keyA' not in sharded_state_dict

            # Check metadata
            assert sharded_state_dict['keyA'].global_shape == (10 * Utils.world_size,)
            assert sharded_state_dict['keyB'].global_shape == (3, 5, 7 * Utils.world_size)
            assert sharded_state_dict['keyA'].local_shape == sharded_state_dict['keyA'].global_shape
            assert sharded_state_dict['keyB'].local_shape == sharded_state_dict['keyB'].global_shape
            assert sharded_state_dict['keyA'].global_offset == (0,)
            assert sharded_state_dict['keyB'].global_offset == (0, 0, 0)
            assert sharded_state_dict['keyA'].axis_fragmentations == (1,)
            assert sharded_state_dict['keyB'].axis_fragmentations == (1, 1, 1)
            assert sharded_state_dict['keyA'].replica_id == 0
            assert sharded_state_dict['keyB'].replica_id == 0

            # metadata dict can be loaded. We don't validate access because there are multiple replica_id=0
            state_dict = load(sharded_state_dict, ckpt_dir, validate_access_integrity=False)
            assert torch.all(state_dict['keyA'] == torch.arange(10 * Utils.world_size))

        Utils.destroy_model_parallel()

    def test_can_mix_sharded_tensors_and_factories(self, tmp_path_dist_ckpt):
        """Round-trip a list mixing ShardedTensors with a ShardedTensorFactory.

        The factory ``D`` expands into three identical parts on save and uses
        ``sum`` as its merge function on load, so the loaded value is three
        times the saved tensor.
        """
        Utils.initialize_model_parallel(1, 1)

        def _build_fn(key, tensor, replica_id, flattened_range):
            assert flattened_range is None
            return [
                ShardedTensor.from_rank_offsets(key + 'part1', tensor, replica_id=replica_id),
                ShardedTensor.from_rank_offsets(key + 'part2', tensor, replica_id=replica_id),
                ShardedTensor.from_rank_offsets(key + 'part3', tensor, replica_id=replica_id),
            ]

        # state dict can be modified by dist_checkpointing.save, so two copies
        def get_sharded_state_dict(base=0):
            return {
                'all': [
                    ShardedTensor.from_rank_offsets(
                        'A', torch.arange(2) + base, replica_id=Utils.rank
                    ),
                    ShardedTensor.from_rank_offsets(
                        'B', torch.arange(3) + base, replica_id=Utils.rank
                    ),
                    ShardedTensor.from_rank_offsets(
                        'C', torch.arange(4) + base, replica_id=Utils.rank
                    ),
                    ShardedTensorFactory(
                        'D', torch.arange(5) + base, _build_fn, sum, replica_id=Utils.rank
                    ),
                ]
            }

        # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
        with TempNamedDir(
            tmp_path_dist_ckpt / 'test_can_mix_sharded_tensors_and_factories', sync=True
        ) as ckpt_dir:
            save(get_sharded_state_dict(0), ckpt_dir)
            loaded_state_dict = load(get_sharded_state_dict(10), ckpt_dir)

        expected_sd = {
            'all': [
                torch.arange(2),
                torch.arange(3),
                torch.arange(4),
                torch.arange(5) * 3,  # sum of three parts, as specified in merge_fn
            ]
        }
        diffs = diff(loaded_state_dict, expected_sd)
        assert not any(map(bool, diffs)), diffs

        Utils.destroy_model_parallel()

    def test_load_error_msg(self, tmp_path_dist_ckpt):
        """Check ``load`` error messages for three failure modes.

        The failure modes are a non-existent directory, an existing but empty
        (non-checkpoint) directory, and a checkpoint that lacks the requested
        key.
        """
        ckpt_dir_name = 'test_load_error_msg'
        Utils.initialize_model_parallel(1, 1)
        sh_ten = ShardedTensor.from_rank_offsets('keyA', torch.rand(10), replica_id=Utils.rank)
        state_dict = {'some_key': sh_ten}

        # Non-existent directory
        non_ex_path = f'/tmp/non-existent-path/{ckpt_dir_name}'
        with pytest.raises(CheckpointingException) as exc_info:
            load(state_dict, non_ex_path)
        assert f'directory {non_ex_path} does not exist' in str(exc_info.value)

        # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
        with TempNamedDir(tmp_path_dist_ckpt / ckpt_dir_name, sync=True) as ckpt_dir:
            # Empty directory - not a distributed checkpoint
            with pytest.raises(CheckpointingException) as exc_info:
                load(state_dict, ckpt_dir)
            assert 'is not a distributed checkpoint' in str(exc_info.value)

            # Missing Zarr arrays
            torch.distributed.barrier()
            save(state_dict, ckpt_dir)
            # Rename the key after saving so the load asks for a key the checkpoint lacks.
            sh_ten.key = 'different_key'
            with pytest.raises((CheckpointingException, PyTCheckpointingException)) as exc_info:
                load(state_dict, ckpt_dir)
            assert "different_key" in str(exc_info.value)

    def test_sharded_object_serialization(self, tmp_path_dist_ckpt):
        """Round-trip a ShardedObject holding a ``torch.save``-serialized dict.

        Loading with a different placeholder object under the same key
        (``sh_obj_A``) returns the originally saved payload.
        """
        Utils.initialize_model_parallel(1, 1)
        # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
        with TempNamedDir(tmp_path_dist_ckpt / 'test_sh_obj', sync=True) as ckpt_dir:
            state = {'some': 'dict'}
            state_serialized = io.BytesIO()
            torch.save(state, state_serialized)
            state_dict = {
                'some_key': ShardedObject(
                    'sh_obj_A', state_serialized, (1,), (0,), replica_id=Utils.rank
                )
            }

            save(state_dict, ckpt_dir)
            del state, state_serialized, state_dict
            other_state = {'other': 'dictionary'}
            other_serialized = io.BytesIO()
            torch.save(other_state, other_serialized)
            state_dict = {
                'other_key': ShardedObject(
                    'sh_obj_A', other_serialized, (1,), (0,), replica_id=Utils.rank
                )
            }
            load_state_dict = load(state_dict, ckpt_dir)
            assert 'other_key' in load_state_dict
            load_state_dict['other_key'].seek(0)
            loaded_state = torch.load(load_state_dict['other_key'])

            assert loaded_state == {'some': 'dict'}

        Utils.destroy_model_parallel()

    def test_tensor_shape_mismatch(self, tmp_path_dist_ckpt):
        """Check ``allow_shape_mismatch`` when the load-time global shape differs.

        A strict tensor (``keyA``) must fail to load with a smaller or larger
        global shape. A flexible tensor (``keyB``) loads, truncated or
        zero-padded to fit.
        """
        Utils.initialize_model_parallel(2, 4)

        # Global tensor is just a range(32) repeated twice over the first dimension
        local_tensor = torch.arange(4).unsqueeze(0).expand(2, 4) + Utils.rank * 4

        state_dict = {
            'rigid': ShardedTensor.from_rank_offsets(
                'keyA', local_tensor, (1, Utils.rank, Utils.world_size)
            ),
            'flexible': ShardedTensor.from_rank_offsets(
                'keyB', local_tensor, (1, Utils.rank, Utils.world_size), allow_shape_mismatch=True
            ),
        }
        assert state_dict['rigid'].global_shape == (2, 32)
        assert state_dict['flexible'].global_shape == (2, 32)

        # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
        with TempNamedDir(tmp_path_dist_ckpt / 'test_tensor_shape_mismatch', sync=True) as ckpt_dir:
            save(state_dict, ckpt_dir)

            pp_size = parallel_state.get_pipeline_model_parallel_world_size()
            pp_rank = parallel_state.get_pipeline_model_parallel_rank()
            tp_rank = parallel_state.get_tensor_model_parallel_rank()

            # Smaller coverage than expected (28 < 32)
            state_dict = {
                'rigid': ShardedTensor.from_rank_offsets(
                    'keyA', torch.ones(2, 7), (1, pp_rank, pp_size), replica_id=tp_rank
                )
            }
            with pytest.raises((CheckpointingException, PyTCheckpointingException)):
                load(state_dict, ckpt_dir)

            state_dict = {
                'flexible': ShardedTensor.from_rank_offsets(
                    'keyB',
                    torch.ones(2, 7),
                    (1, pp_rank, pp_size),
                    replica_id=tp_rank,
                    allow_shape_mismatch=True,
                )
            }
            loaded_state_dict = load(state_dict, ckpt_dir)
            assert torch.all(
                loaded_state_dict['flexible']
                == torch.arange(7).unsqueeze(0).expand(2, 7) + pp_rank * 7
            )

            # Larger coverage than expected (36 > 32)
            state_dict = {
                'rigid': ShardedTensor.from_rank_offsets(
                    'keyA', torch.ones(2, 9), (1, pp_rank, pp_size), replica_id=tp_rank
                )
            }
            with pytest.raises((CheckpointingException, PyTCheckpointingException)):
                load(state_dict, ckpt_dir)

            state_dict = {
                'flexible': ShardedTensor.from_rank_offsets(
                    'keyB',
                    torch.ones(2, 9),
                    (1, pp_rank, pp_size),
                    replica_id=tp_rank,
                    allow_shape_mismatch=True,
                )
            }
            loaded_state_dict = load(state_dict, ckpt_dir)
            expected_tensor = torch.arange(9).unsqueeze(0).expand(2, 9) + pp_rank * 9

            # The last PP shard starts at offset 27, so only 5 of its 9 columns exist on disk.
            if pp_rank >= (32 // 9):
                assert pp_rank == 3, pp_rank
                expected_tensor[:, 5:] = 0  # padding with 0s
            assert torch.all(loaded_state_dict['flexible'] == expected_tensor)

        Utils.destroy_model_parallel()

    @pytest.mark.skipif(
        not is_torch_min_version("2.3.0"),
        reason="remove_sharded_tensors relies on Torch APIs introduced in v2.3.0",
    )
    @pytest.mark.flaky_in_dev
    def test_remove_sharded_tensors(self, tmp_path_dist_ckpt):
        """Check that ``remove_sharded_tensors`` drops every tensor whose key has a prefix.

        The checkpoint is saved with ``separation_hint`` set to that prefix,
        which produces one prefixed file per rank. After removal, those files
        and their metadata entries are gone while ``keyA`` remains.
        """
        Utils.initialize_model_parallel(2, 4)

        state_dict = {
            'sd_keyA': ShardedTensor.from_rank_offsets(
                'keyA', torch.ones(2, 4), (0, Utils.rank, Utils.world_size)
            ),
            'sd_prefix_key_to_remove': ShardedTensor.from_rank_offsets(
                'prefix_key_to_remove', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size)
            ),
        }

        prefix_name = "prefix"  ## we will drop all tensors whose keys begin with "prefix"

        # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
        with TempNamedDir(
            tmp_path_dist_ckpt / 'test_remove_sharded_tensor_prefix', sync=True
        ) as ckpt_dir:
            save_strategy = TorchDistSaveShardedStrategy(
                "torch_dist", 1, separation_hint=prefix_name
            )
            save(state_dict, ckpt_dir, save_strategy)

            files = os.listdir(ckpt_dir)
            prefix_files = [f for f in files if f.startswith(prefix_name)]
            assert len(prefix_files) == torch.distributed.get_world_size()

            fs_reader = FileSystemReader(ckpt_dir)
            original_metadata = fs_reader.read_metadata()
            assert set(original_metadata.state_dict_metadata.keys()) == {
                'keyA',
                'prefix_key_to_remove',
            }

            # Only one rank rewrites the checkpoint; the barrier keeps the others
            # from inspecting the directory before the removal finishes.
            if torch.distributed.get_rank() == 0:
                remove_sharded_tensors(ckpt_dir, key_prefix=prefix_name)
            torch.distributed.barrier()

            files = os.listdir(ckpt_dir)
            prefix_files = [f for f in files if f.startswith(prefix_name)]
            assert len(prefix_files) == 0

            new_metadata = fs_reader.read_metadata()
            assert set(new_metadata.state_dict_metadata.keys()) == {'keyA'}

        Utils.destroy_model_parallel()

    def test_empty_load(self, tmp_path_dist_ckpt):
        """Check save/load when some ranks have only non-sharded (common) entries.

        Only rank 0's common state dict is saved, so ranks 0-2 load back just
        ``'common'``. Ranks 3-7 also load ``'a'``, whose value comes from rank 3,
        the rank holding the replica with ``replica_id == 0``.
        """
        Utils.initialize_model_parallel(2, 4)

        if Utils.rank == 0:
            state_dict = {'common': 'common-value'}
        elif Utils.rank == 1:
            state_dict = {'a': 3}  # this is not saved at all (common saved by rank 0 only)
        elif Utils.rank == 2:
            state_dict = {'b': 3}  # this is not saved at all (common saved by rank 0 only)
        else:
            state_dict = {
                'a': ShardedTensor.from_rank_offsets(
                    'x', torch.ones((2,)) * Utils.rank, replica_id=Utils.rank - 3
                )
            }

        with TempNamedDir(tmp_path_dist_ckpt / 'test_empty_load', sync=True) as ckpt_dir:
            save(state_dict, ckpt_dir)
            torch.distributed.barrier()
            loaded_state_dict = load(state_dict, ckpt_dir)
            assert loaded_state_dict['common'] == 'common-value'

            if Utils.rank <= 2:
                assert loaded_state_dict.keys() == {'common'}
            else:
                assert loaded_state_dict.keys() == {'common', 'a'}
                # rank 3 held the main replica so did the saving
                assert loaded_state_dict['a'].cpu().numpy().tolist() == [3, 3]

        Utils.destroy_model_parallel()


class TestNonStrictLoad:
    def setup_method(self, method):
        """Initialize model parallel state before each test in this class.

        The concrete parallel sizes are irrelevant to the non-strict load tests.
        """
        parallel_sizes = (2, 4)
        Utils.initialize_model_parallel(*parallel_sizes)

    def teardown_method(self, method):
        """Release the model parallel state after each test in this class."""
        destroy_parallel_state = Utils.destroy_model_parallel
        destroy_parallel_state()

    def _get_base_state_dict(self):
        """Build the reference sharded state dict used by the non-strict load tests.

        Returns a dict with three ShardedTensors ('TenA', 'TenB', 'TenC') and two
        ShardedObjects ('ObjA', 'ObjB'). 'TenB' and 'ObjB' get one shard per rank
        (replica_id 0 everywhere). 'TenA', 'TenC' and 'ObjA' are whole objects
        on every rank with rank-dependent replica ids.
        """
        rank, world_size = Utils.rank, Utils.world_size

        state_dict = {}
        state_dict['TenA'] = ShardedTensor.from_rank_offsets(
            'TenA', torch.arange(2), replica_id=rank
        )
        # Sharded along axis 0: this rank's slot `rank` out of `world_size` slots.
        state_dict['TenB'] = ShardedTensor.from_rank_offsets(
            'TenB', torch.arange(3), (0, rank, world_size), replica_id=0
        )
        # The last rank holds replica 0 of 'TenC'.
        state_dict['TenC'] = ShardedTensor.from_rank_offsets(
            'TenC', torch.arange(3), replica_id=world_size - rank - 1
        )
        state_dict['ObjA'] = ShardedObject(
            'ObjA', list(range(10)), (1,), (0,), replica_id=rank
        )
        # One ObjB shard per rank; its payload encodes the owning rank.
        state_dict['ObjB'] = ShardedObject(
            'ObjB', {rank + 7}, (1, world_size), (0, rank), replica_id=0
        )
        return state_dict

    @pytest.mark.parametrize('save_format', ['zarr', 'torch_dist'])
    @pytest.mark.parametrize('validate_integrity', [True, False])
    def test_unexpected_keys_handling_during_validation(
        self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format
    ):
        """Exercise every StrictHandling mode when the loaded dict has extra keys.

        The checkpoint is saved from the base state dict. Each load then adds
        'UnexpectedTenD' (ShardedTensor) and 'UnexpectedObjD' (ShardedObject),
        which are not in the checkpoint. No keys are missing.
        """
        sharded_state_dict = self._get_base_state_dict()
        with TempNamedDir(
            tmp_path_dist_ckpt / 'test_unexpected_keys_raises_error_during_validation'
        ) as ckpt_dir:
            save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1)
            save(sharded_state_dict, ckpt_dir, save_strategy)

            def load_with_flag(strict):
                # Fresh base dict plus two entries absent from the checkpoint.
                sharded_state_dict = self._get_base_state_dict()
                sharded_state_dict['TenD'] = ShardedTensor.from_rank_offsets(
                    'UnexpectedTenD', torch.arange(3), replica_id=Utils.rank
                )
                sharded_state_dict['ObjD'] = ShardedObject(
                    'UnexpectedObjD', None, (1,), (0,), replica_id=Utils.rank
                )
                return load(
                    sharded_state_dict,
                    ckpt_dir,
                    validate_access_integrity=validate_integrity,
                    strict=strict,
                )

            def test_error(error_msg):
                # The report must name both unexpected keys and no missing keys.
                assert 'Unexpected keys' in error_msg
                assert 'UnexpectedTenD' in error_msg
                assert 'UnexpectedObjD' in error_msg
                assert 'Missing keys' not in error_msg

            # ASSUME_OK_UNEXPECTED results in an exception raised by the underlying strategy
            with pytest.raises(
                PyTCheckpointingException if save_format == 'torch_dist' else CheckpointingException
            ):
                load_with_flag(StrictHandling.ASSUME_OK_UNEXPECTED)
            # Informative exceptions with `RAISE_*` options:
            with pytest.raises(CheckpointingException) as exc_info:
                load_with_flag(StrictHandling.RAISE_UNEXPECTED)
            test_error(str(exc_info.value))
            with pytest.raises(CheckpointingException) as exc_info:
                load_with_flag(StrictHandling.RAISE_ALL)
            test_error(str(exc_info.value))

            # Logged mismatches. `caplog.text` accumulates over the whole test, so clear it
            # before each load. Otherwise LOG_ALL would pass on LOG_UNEXPECTED's output alone.
            caplog.clear()
            with caplog.at_level(logging.WARNING):
                loaded_state_dict = load_with_flag(StrictHandling.LOG_UNEXPECTED)
            assert 'TenA' in loaded_state_dict
            test_error(caplog.text)
            caplog.clear()
            with caplog.at_level(logging.WARNING):
                loaded_state_dict = load_with_flag(StrictHandling.LOG_ALL)
            assert 'TenA' in loaded_state_dict
            test_error(caplog.text)

            # Returned mismatches
            loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(
                StrictHandling.RETURN_UNEXPECTED
            )
            assert 'TenA' in loaded_state_dict
            assert unexpected_keys == {'UnexpectedTenD', 'UnexpectedObjD'}
            assert missing_keys == set()
            loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(
                StrictHandling.RETURN_ALL
            )
            assert 'TenA' in loaded_state_dict
            assert unexpected_keys == {'UnexpectedTenD', 'UnexpectedObjD'}
            assert missing_keys == set()

            # Ignore mismatch
            loaded_state_dict = load_with_flag(StrictHandling.IGNORE_ALL)
            assert 'TenA' in loaded_state_dict

    @pytest.mark.parametrize('save_format', ['zarr', 'torch_dist'])
    @pytest.mark.parametrize('validate_integrity', [True, False])
    def test_missing_keys_raises_error_during_validation(
        self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format
    ):
        """Exercise every StrictHandling mode when the loaded dict omits saved keys.

        'TenA' and 'ObjB' are dropped before each load. The `*_UNEXPECTED`
        modes must not report anything. The `*_ALL` modes must report both
        keys as missing.
        """
        base_state_dict = self._get_base_state_dict()
        with TempNamedDir(
            tmp_path_dist_ckpt / 'test_missing_keys_raises_error_during_validation'
        ) as ckpt_dir:
            strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1)
            save(base_state_dict, ckpt_dir, strategy)

            def load_with_flag(strict):
                partial_state_dict = self._get_base_state_dict()
                for dropped_key in ('TenA', 'ObjB'):
                    del partial_state_dict[dropped_key]
                return load(
                    partial_state_dict,
                    ckpt_dir,
                    validate_access_integrity=validate_integrity,
                    strict=strict,
                )

            def check_report(report):
                # The report must name both missing keys and no unexpected keys.
                assert 'Unexpected keys' not in report
                assert 'TenA' in report
                assert 'ObjB' in report
                assert 'Missing keys' in report

            # Modes that only track unexpected keys see no mismatch here.
            for strict in (
                StrictHandling.ASSUME_OK_UNEXPECTED,
                StrictHandling.RAISE_UNEXPECTED,
            ):
                result = load_with_flag(strict)
                assert 'TenB' in result

            with caplog.at_level(logging.WARNING):
                result = load_with_flag(StrictHandling.LOG_UNEXPECTED)
            # Only the zarr deprecation notice is tolerated in the log.
            assert (
                caplog.text == ''
                or '`zarr` distributed checkpoint backend is deprecated' in caplog.text
            )
            assert 'TenB' in result

            result, missing, unexpected = load_with_flag(StrictHandling.RETURN_UNEXPECTED)
            assert 'TenB' in result
            assert missing == set()
            assert unexpected == set()

            result = load_with_flag(StrictHandling.IGNORE_ALL)
            assert 'TenB' in result

            # RAISE_ALL must raise an exception that names the missing keys.
            with pytest.raises(CheckpointingException) as exc_info:
                load_with_flag(StrictHandling.RAISE_ALL)
            check_report(str(exc_info.value))

            # LOG_ALL must log the same report and still load.
            with caplog.at_level(logging.WARNING):
                result = load_with_flag(StrictHandling.LOG_ALL)
            assert 'TenB' in result
            check_report(caplog.text)

            # RETURN_ALL hands the mismatches back to the caller.
            result, missing, unexpected = load_with_flag(StrictHandling.RETURN_ALL)
            assert 'TenB' in result
            assert unexpected == set()
            assert missing == {'TenA', 'ObjB'}

    @pytest.mark.parametrize('save_format', ['zarr', 'torch_dist'])
    @pytest.mark.parametrize('validate_integrity', [True, False])
    def test_exact_load_handling(self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format):
        """Check that no StrictHandling mode reports a mismatch when keys match exactly.

        The checkpoint is saved from and loaded into the same base state dict.
        Every mode must load 'TenB' and 'ObjB' and emit no warning except the
        zarr deprecation notice. RETURN_* modes must return empty key sets.
        """
        sharded_state_dict = self._get_base_state_dict()
        with TempNamedDir(tmp_path_dist_ckpt / 'test_exact_load_handling') as ckpt_dir:
            save(
                sharded_state_dict,
                ckpt_dir,
                get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1),
            )

            # Modes whose `load` returns a (state_dict, missing, unexpected) triple.
            returning_modes = (StrictHandling.RETURN_UNEXPECTED, StrictHandling.RETURN_ALL)
            modes_in_order = (
                StrictHandling.ASSUME_OK_UNEXPECTED,
                StrictHandling.LOG_UNEXPECTED,
                StrictHandling.LOG_ALL,
                StrictHandling.RAISE_UNEXPECTED,
                StrictHandling.RAISE_ALL,
                StrictHandling.IGNORE_ALL,
                *returning_modes,
            )

            for strict in modes_in_order:
                with caplog.at_level(logging.WARNING):
                    outcome = load(
                        self._get_base_state_dict(),
                        ckpt_dir,
                        validate_access_integrity=validate_integrity,
                        strict=strict,
                    )
                assert (
                    caplog.text == ''
                    or '`zarr` distributed checkpoint backend is deprecated' in caplog.text
                )
                if strict in returning_modes:
                    loaded_state_dict, missing_keys, unexpected_keys = outcome
                else:
                    loaded_state_dict, missing_keys, unexpected_keys = outcome, set(), set()
                assert 'TenB' in loaded_state_dict
                assert 'ObjB' in loaded_state_dict
                assert missing_keys == set()
                assert unexpected_keys == set()

    @pytest.mark.parametrize('save_format', ['zarr', 'torch_dist'])
    def test_sharded_metadata(self, tmp_path_dist_ckpt, save_format):
        """Check `load_sharded_metadata` keys and loading a checkpoint back from them.

        Expected shard keys and tensor sizes are derived from `Utils.world_size`.
        This matches how `_get_base_state_dict` splits 'TenB' and 'ObjB'
        (one shard per rank), so the test does not assume exactly 8 ranks.
        """
        world_size = Utils.world_size
        sharded_state_dict = self._get_base_state_dict()
        # Directory name unique to this test so it cannot clash with another test's checkpoint.
        with TempNamedDir(tmp_path_dist_ckpt / 'test_sharded_metadata') as ckpt_dir:
            save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1)
            save(sharded_state_dict, ckpt_dir, save_strategy)
            torch.distributed.barrier()
            sharded_metadata = load_sharded_metadata(ckpt_dir)
            assert set(sh_base.key for sh_base in sharded_metadata.values()) == {
                'TenA',
                'TenB',
                'TenC',
                'ObjA',
                'ObjB',
            }
            # ShardedObjects are keyed per shard: 'ObjB' has global shape (1, world_size)
            # and one shard at offset (0, i) for each rank i.
            obj_b_keys = [f'ObjB/shard_0.{i}_1.{world_size}' for i in range(world_size)]
            assert set(sharded_metadata.keys()) == {
                'TenA',
                'TenB',
                'TenC',
                'ObjA/shard_0_1',
                *obj_b_keys,
            }

            loaded_state_dict = load(sharded_metadata, ckpt_dir, validate_access_integrity=False)

            assert loaded_state_dict['ObjA/shard_0_1'] == list(range(10))
            for shard_idx, obj_b_key in enumerate(obj_b_keys):
                # Rank `shard_idx` saved {shard_idx + 7} for its ObjB shard.
                assert loaded_state_dict[obj_b_key] == {shard_idx + 7}
            assert torch.all(loaded_state_dict['TenA'] == torch.arange(2))
            # Each rank contributed one arange(3) slice of TenB along axis 0.
            assert torch.all(loaded_state_dict['TenB'] == torch.arange(3).repeat(world_size))
            assert torch.all(loaded_state_dict['TenC'] == torch.arange(3))