"configs/datasets/SVAMP/svamp_gen_fb25e4.py" did not exist on "8c1483e3ce08773e43e7d57bb5eb6f6aa2fcdb08"
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
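
"""Unit tests for ShardedTensor constructors, metadata validation, and the
ShardedTensorFactory build/merge flow in megatron.core.dist_checkpointing."""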

import pytest

import torch

from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.mapping import (
    is_main_replica,
    ShardedTensorFactory,
    ShardedObject,
    apply_factories,
    apply_factory_merges,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils

class TestShardedTensor:

    # def setup_method(self, method):
    #     Utils.initialize_model_parallel(1,1)
    #     transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
    #     self.gpt_embedding = GPTEmbedding(config=transformer_config, vocab_size=100, max_sequence_length=4, add_position_embedding=True)
    #
    # def teardown_method(self, method):
    #     Utils.destroy_model_parallel()
    
    def test_from_rank_offsets_constructor(self, dtype=torch.float, device='cuda'):
        data = torch.ones((1, 3, 7, 9), dtype=dtype, device=device)
        shape = data.shape
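        # Each rank offset is (axis, this rank's position along the axis, number of shards
        # along the axis): here shard 0 of 10 along dim 0 and shard 3 of 6 along dim 2.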
        rank_offsets = [
            (0, 0, 10),
            (2, 3, 6)
        ]
        sh_ten = ShardedTensor.from_rank_offsets('keyA', data, *rank_offsets)

        assert isinstance(sh_ten, ShardedTensor)
        assert sh_ten.dtype is dtype
        assert sh_ten.local_shape == shape
        assert sh_ten.global_shape == (shape[0] * 10, shape[1], shape[2] * 6, shape[3])
        assert sh_ten.global_offset == (0, 0, shape[2] * 3, 0)
        assert sh_ten.axis_fragmentations == (10, 1, 6, 1)

    def test_from_rank_offsets_flat_constructor(self, dtype=torch.float, device='cuda'):
        data = torch.arange(28, dtype=dtype, device=device).reshape((1, 4, 7))
        shape = data.shape
        rank_offsets = [
            (1, 0, 2),
            (2, 3, 5)
        ]
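        # flattened_range selects elements [4, 9) of the flattened local tensor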
        flattened_range = slice(4, 9)
        flat_data = data.flatten()[flattened_range]
        sh_ten = ShardedTensor.from_rank_offsets_flat('keyA', flat_data, data.shape, *rank_offsets, flattened_range=flattened_range)

        # The main shape and offset attributes are the same as for the non-flattened constructor
        assert isinstance(sh_ten, ShardedTensor)
        assert sh_ten.dtype is dtype
        assert sh_ten.local_shape == shape
        assert sh_ten.global_shape == (shape[0], shape[1] * 2, shape[2] * 5)
        assert sh_ten.global_offset == (0, 0, shape[2] * 3)
        assert sh_ten.axis_fragmentations == (1, 2, 5)

        assert torch.all(sh_ten.data == torch.arange(4, 9, device=device))

    def test_metadata_integrity_violation(self):
        data = torch.ones((1, 3, 7, 9), device='meta')
        rank_offsets = [
            (0, 0, 10),
            (2, 3, 6)
        ]
        sh_ten = ShardedTensor.from_rank_offsets('keyA', data, *rank_offsets)
        sh_ten.validate_metadata_integrity()
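        # Mutating shape or offset metadata after construction must fail validation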
        with pytest.raises(CheckpointingException):
            sh_ten.local_shape = (1, 2, 7, 9)
            sh_ten.validate_metadata_integrity()

        sh_ten = ShardedTensor.from_rank_offsets('keyA', data, *rank_offsets)
        with pytest.raises(CheckpointingException):
            sh_ten.global_offset = (0, 1, 0)
            sh_ten.validate_metadata_integrity()

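        # Passing non-flattened data together with flattened_range must be rejected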
        with pytest.raises(CheckpointingException):
            sh_ten = ShardedTensor.from_rank_offsets_flat('keyA', data, data.shape, *rank_offsets,
                                                          flattened_range=slice(4, 9))

        sh_ten = ShardedTensor.from_rank_offsets_flat('keyA', data.flatten()[4:9], data.shape, *rank_offsets,
                                                      flattened_range=slice(4, 9))
        assert sh_ten.local_shape == (1, 3, 7, 9)
        with pytest.raises(CheckpointingException):
            sh_ten.local_shape = (5,)
            sh_ten.validate_metadata_integrity()



class TestShardedTensorFactory:
    def test_build_and_merge(self):
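        # build_fn expands the factory's tensor into two ShardedTensors under separate keys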
        def build_fn(key, tensor, replica_id, flattened_range):
            assert flattened_range is None
            return {
                'level2_a': ShardedTensor.from_rank_offsets(key + 'part1', tensor + 1, replica_id=replica_id),
                'level2_b': ShardedTensor.from_rank_offsets(key + 'part2', tensor + 2, replica_id=replica_id)
            }

        # state_dict will be modified in-place
        def get_state_dict():
            return {
                'level1': ShardedTensorFactory('a', torch.arange(3), build_fn, lambda x: x['level2_b'])
            }
        state_dict = get_state_dict()
        apply_factories(state_dict)
        assert torch.allclose(state_dict['level1']['level2_a'].data, torch.tensor([1, 2, 3]))
        assert torch.allclose(state_dict['level1']['level2_b'].data, torch.tensor([2, 3, 4]))

        # Simulate loading
        state_dict['level1']['level2_a'] = state_dict['level1']['level2_a'].data
        state_dict['level1']['level2_b'] = state_dict['level1']['level2_b'].data

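        # apply_factory_merges runs the factory's merge_fn, which here keeps only 'level2_b'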
        loaded_state_dict = apply_factory_merges(state_dict, get_state_dict())
        assert torch.allclose(loaded_state_dict['level1'], torch.tensor([2, 3, 4]))


def test_is_main_replica():
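    # Only replica id 0 (or an all-zeros tuple) counts as the main replica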
    assert is_main_replica(0)
    assert is_main_replica((0,))
    assert is_main_replica((0, 0))
    assert not is_main_replica(1)
    assert not is_main_replica(2)
    assert not is_main_replica((1,))
    assert not is_main_replica((1, 0))
    assert not is_main_replica((1, 1, 1))