# Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.

import logging
import os
from collections import Counter, defaultdict
from itertools import chain
from pathlib import Path
from typing import Iterable, List, Tuple, Union

import numpy as np
import torch

from .core import CheckpointingConfig, maybe_load_config, save_config
from .dict_utils import (
    dict_list_map_inplace,
    diff,
    extract_matching_values,
    map_reduce,
    merge,
    nested_values,
)
from .mapping import (
    CheckpointingException,
    ShardedObject,
    ShardedStateDict,
    ShardedTensor,
    ShardedTensorFactory,
    StateDict,
    apply_factories,
    apply_factory_merges,
    is_main_replica,
)
from .strategies.base import (
    LoadCommonStrategy,
    LoadShardedStrategy,
    SaveCommonStrategy,
    SaveShardedStrategy,
    StrategyAction,
    get_default_strategy,
)
from .utils import extract_sharded_tensors, extract_sharded_tensors_or_nonpersistent

COMMON_STATE_FNAME = 'common.pt'

logger = logging.getLogger(__name__)


def load(
    sharded_state_dict: ShardedStateDict,
    checkpoint_dir: str,
    sharded_strategy: Union[LoadShardedStrategy, None] = None,
    common_strategy: Union[LoadCommonStrategy, None] = None,
) -> StateDict:
    """Loading entrypoint.

    Arguments:
        sharded_state_dict: state dict of the existing model populated with
            ShardedTensors. Used as a mapping to determine which parts of
            global tensors stored in the checkpoint should be loaded.
        checkpoint_dir: directory with the checkpoint
        sharded_strategy: configures loading behavior for sharded tensors
        common_strategy: configures loading behavior for common data
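
    Example (illustrative sketch; intended to be run collectively on every rank
    of the ``torch.distributed`` group, and ``model.sharded_state_dict()`` is
    an assumed caller-side helper, not part of this module):

        state_dict = load(model.sharded_state_dict(), '/path/to/checkpoint')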
    """
    if common_strategy is not None:
        raise NotImplementedError('The only supported common strategy is torch')

    checkpoint_dir = Path(checkpoint_dir)
    common_state_dict = load_common_state_dict(checkpoint_dir)
    if not sharded_state_dict:
        return common_state_dict

    sharded_objects, sharded_state_dict = load_sharded_objects(sharded_state_dict, checkpoint_dir)
    merge(common_state_dict, sharded_objects)

    saved_config = maybe_load_config(checkpoint_dir)
    if saved_config is None:
        raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')

    sh_ten_factories, _ = extract_matching_values(
        sharded_state_dict, lambda x: isinstance(x, ShardedTensorFactory)
    )
    apply_factories(sharded_state_dict)
    sharded_state_dict, _ = extract_sharded_tensors_or_nonpersistent(sharded_state_dict)
    sharded_state_dict, nonpersistent_state_dict = extract_sharded_tensors(sharded_state_dict)
    dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict)
    merge(common_state_dict, nonpersistent_state_dict)

    validate_sharding_integrity(nested_values(sharded_state_dict))

    if sharded_strategy is None:
        sharded_strategy = get_default_strategy(
            StrategyAction.LOAD_SHARDED,
            saved_config.sharded_backend,
            saved_config.sharded_backend_version,
        )
    else:
        # TODO: implement consistency checks here
        pass
    loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)

    loaded_state_dict = apply_factory_merges(loaded_state_dict, sh_ten_factories)

    merge(common_state_dict, loaded_state_dict)
    return common_state_dict


# TODO: implement it as common torch strategy
def load_common_state_dict(checkpoint_dir: Path):
    return torch.load(Path(checkpoint_dir) / COMMON_STATE_FNAME, map_location='cpu')


def load_sharded_objects(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
    sharded_objects, sharded_state_dict = extract_matching_values(
        sharded_state_dict, lambda v: isinstance(v, ShardedObject)
    )

    def load_sharded_object(sh_obj: ShardedObject):
        sh_obj.data = None
        load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
        loaded_obj = torch.load(load_path)
        return loaded_obj

    return dict_list_map_inplace(load_sharded_object, sharded_objects), sharded_state_dict


def save(
    sharded_state_dict: ShardedStateDict,
    checkpoint_dir: str,
    sharded_strategy: Union[SaveShardedStrategy, None] = None,
    common_strategy: Union[SaveCommonStrategy, None] = None,
):
    """Saving entrypoint.

    Extracts ShardedTensors from the given state dict. Rank 0 saves the
    "regular" part of the checkpoint to common torch file.
    The ShardedTensors are saved according to a strategy specified by the
    config.

    Arguments:
        sharded_state_dict: state dict of the model to save, populated with
            ShardedTensors. Used as a mapping to determine how local tensors
            should be saved as global tensors in the checkpoint.
        checkpoint_dir: directory to save the checkpoint to
        sharded_strategy: configures sharded tensors saving behavior and backend
        common_strategy: configures common data saving behavior and backend
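
    Example (illustrative sketch; intended to be run collectively on every
    rank, with a checkpoint directory that already exists and is empty;
    ``model.sharded_state_dict()`` is an assumed caller-side helper):

        save(model.sharded_state_dict(), '/path/to/checkpoint')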
    """
    checkpoint_dir = Path(checkpoint_dir)

    if torch.distributed.get_rank() == 0:
        if not checkpoint_dir.exists():
            raise CheckpointingException(
                f'Checkpoint destination directory does not exist: {checkpoint_dir}'
            )

        if next(checkpoint_dir.iterdir(), None) is not None:
            raise CheckpointingException(
                f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
            )

    if common_strategy is not None:
        raise NotImplementedError('The only supported common strategy is torch')

    if sharded_strategy is None:
        sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'zarr', 1)

    apply_factories(sharded_state_dict)
    sharded_state_dict, state_dict = extract_sharded_tensors_or_nonpersistent(sharded_state_dict)
    sharded_state_dict, _ = extract_sharded_tensors(sharded_state_dict)
    sharded_tensors = list(nested_values(sharded_state_dict))
    validate_sharding_integrity(sharded_tensors)

    _save_common_dict(state_dict, checkpoint_dir, True)

    sharded_strategy.save(sharded_tensors, checkpoint_dir)
    save_config(
        CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), checkpoint_dir
    )


# TODO: implement it as common torch strategy
def _save_common_dict(
    state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
):
    common_state_dict = _extract_and_save_sharded_objects(
        state_dict, checkpoint_dir, validate_consistency
    )
    if torch.distributed.get_rank() == 0:
        torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME)
    if validate_consistency:
        # TODO: implement checking consistency with rank 0 common dict on other ranks
        pass
        # torch.distributed.barrier()
        # if not torch.distributed.get_rank() == 0:
        #     rank_0_state_dict = torch.load(checkpoint_dir / COMMON_STATE_FNAME)
        #     print(diff(common_state_dict, rank_0_state_dict))


def _extract_and_save_sharded_objects(
    state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
):
    sharded_objects, state_dict = extract_matching_values(
        state_dict, lambda v: isinstance(v, ShardedObject)
    )
    sharded_objects = list(nested_values(sharded_objects))
    if validate_consistency:
        validate_objects_sharding_integrity(sharded_objects)
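    # Only the main replica of each object is persisted; other replicas hold
    # duplicates of the same data and are skipped.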
    for sh_obj in sharded_objects:
        if is_main_replica(sh_obj.replica_id):
            save_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
            os.makedirs(save_path.parent, exist_ok=True)
            torch.save(sh_obj.data, save_path)
    return state_dict


def validate_sharding_integrity(sharded_tensors: Iterable[ShardedTensor]):
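    """Validate that the ShardedTensors from all ranks form a consistent sharding.

    ShardedTensor metadata (without data) is all-gathered across ranks; rank 0
    then checks, for each tensor key, that all shards agree on dtype and shapes
    and that main replicas cover the global tensor exactly once.
    """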
    sharding = [ten.without_data() for ten in sharded_tensors]
    all_sharding = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(all_sharding, sharding)
    if torch.distributed.get_rank() != 0:
        return

    key_shardings = defaultdict(list)
    for rank, rank_shardings in enumerate(all_sharding):
        for sharding in rank_shardings:
            key_shardings[sharding.key].append((rank, sharding))
    for key, shardings in key_shardings.items():
        _validate_sharding_for_key(shardings)


def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
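    # All shards of a key must agree on dtype, global/local shape and on whether
    # they use a flattened range; main replicas must then cover every chunk of
    # the global tensor exactly once.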
    some_rank_shard = rank_sharding[0][1]
    global_shape = some_rank_shard.global_shape
    local_shape = some_rank_shard.local_shape
    dtype = some_rank_shard.dtype
    has_flattened_range = some_rank_shard.flattened_range is not None
    for rank, sharding in rank_sharding:
        assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
        assert sharding.global_shape == global_shape, (
            sharding.global_shape,
            global_shape,
            some_rank_shard,
        )
        assert sharding.local_shape == local_shape, (
            sharding.local_shape,
            local_shape,
            some_rank_shard,
        )
        assert (sharding.flattened_range is not None) == has_flattened_range, (
            (sharding.flattened_range is not None),
            has_flattened_range,
            some_rank_shard,
        )

    shard_access_cnt = _compute_shards_access(rank_sharding)
    if has_flattened_range:
        map_reduce(
            rank_sharding,
            lambda x: x[1].global_offset,
            lambda x: x[1],
            _validate_sharding_for_key_flattened,
        )
    else:
        if not torch.all(shard_access_cnt == 1):
            logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
            raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')


def _compute_shards_access(rank_sharding):
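    # Count how many main replicas claim each chunk of the fragmentation grid.
    # chunk_offset maps a shard's global offset to its grid index: prepended
    # axes keep their raw offset, the remaining axes are divided by local_shape.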
    def chunk_offset(sharding):
        assert len(sharding.global_offset) == len(sharding.local_shape) + sharding.prepend_axis_num
        return tuple(
            chain(
                (off for off in sharding.global_offset[: sharding.prepend_axis_num]),
                (
                    off // sh
                    for off, sh in zip(
                        sharding.global_offset[sharding.prepend_axis_num :], sharding.local_shape
                    )
                ),
            )
        )

    shard_access_cnt = torch.zeros(
        rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu'
    )
    for rank, sharding in rank_sharding:
        if is_main_replica(sharding.replica_id):
            shard_access_cnt[chunk_offset(sharding)] += 1
        # TODO: consider validating different replicas too
    return shard_access_cnt


def _validate_sharding_for_key_flattened(tensors_by_shard):
    all_slices = []
    local_shape = tensors_by_shard[0].local_shape
    for sharding in tensors_by_shard:
        assert sharding.local_shape == local_shape
        sharding: ShardedTensor
        if not is_main_replica(sharding.replica_id):
            # TODO: this checks only saving (and loading replica_id=0) consistency
            continue

        all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))

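    # The sorted main-replica ranges must tile the flattened shard
    # [0, prod(local_shape)) without gaps or overlaps.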
    starts, stops = map(np.asarray, zip(*sorted(all_slices)))
    if (
        starts[0] != 0
        or stops[-1] != np.prod(local_shape)
        or not np.all(starts[1:] == stops[:-1])
    ):
        logger.error(
            f"Flattened ranges don't cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}"
        )
        raise CheckpointingException(
            f"Flattened ranges don't cover the whole shard {tensors_by_shard[0]}"
        )


def validate_objects_sharding_integrity(sharded_objects: List[ShardedObject]):
    """ Ensure uniqueness of saved objects. """
    local_sh_objs = [sh_obj.without_data() for sh_obj in sharded_objects]
    all_sh_objs = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(all_sh_objs, local_sh_objs)
    if torch.distributed.get_rank() != 0:
        return
    unique_keys = [
        sh_obj.unique_key
        for sh_obj in chain.from_iterable(all_sh_objs)
        if is_main_replica(sh_obj.replica_id)
    ]
    if len(unique_keys) != len(set(unique_keys)):
        duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
        logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}')
        raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}')