checkpoint_engine.py 654 Bytes
Newer Older
aiss's avatar
aiss committed
1
2
3
4
5
6
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
aiss's avatar
aiss committed
7
8
9
10
11
12
13
14
15
16
17
18


class CheckpointEngine(object):
    """Base interface for saving and loading checkpoints.

    Every hook except ``makedirs`` is a no-op that returns ``None``.
    Concrete engines are presumably expected to subclass this and
    override the hooks they need (TODO confirm against subclasses
    elsewhere in the project).
    """

    def __init__(self, config_params=None):
        """Initialize the checkpoint engine for save/load.

        Args:
            config_params: Engine-specific configuration. The base class
                ignores it.
        """
        pass

    def create(self, tag):
        """Create a checkpoint for the given ``tag`` before save/load.

        The base implementation does nothing.
        """
        pass

    def makedirs(self, path, exist_ok=False):
        """Create directory ``path``, including missing parents.

        Delegates directly to :func:`os.makedirs`. With
        ``exist_ok=False``, an already-existing directory raises
        ``FileExistsError``.
        """
        os.makedirs(path, exist_ok=exist_ok)

    def save(self, state_dict, path: str):
        """Save ``state_dict`` to ``path``. No-op in the base class."""
        pass

    def load(self, path: str, map_location=None):
        """Load a checkpoint from ``path``.

        The base class does nothing and returns ``None``.
        ``map_location`` is accepted for interface compatibility only.
        """
        pass

    def commit(self, tag):
        """Tell checkpoint services that all files for ``tag`` are ready.

        The base implementation does nothing and returns ``None``.
        """
        pass