# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

from .checkpoint_manager import BaseCheckpointManager


class FSDPCheckpointManager(BaseCheckpointManager):
    """
    A checkpoint manager that saves and loads
    - model
    - optimizer
    - lr_scheduler
    - extra_states
    in a SPMD way.

    We save
    - sharded model states and optimizer states
    - full lr_scheduler states
    - huggingface tokenizer and config for ckpt merge
    """

    def __init__(
        self,
        model: FSDP,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
    ):
        super().__init__(model, optimizer, lr_scheduler, processing_class)

    def load_checkpoint(self, path: Optional[str] = None):
        if path is None:
            return

        # every rank downloads its own checkpoint shards
        model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
        optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
        extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
        print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
        model_state_dict = torch.load(model_path, weights_only=False)
        optimizer_state_dict = torch.load(optim_path, weights_only=False)
        extra_state_dict = torch.load(extra_state_path, weights_only=False)
        lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]

        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
                self.model.load_state_dict(model_state_dict)
                if self.optimizer is not None:
                    self.optimizer.load_state_dict(optimizer_state_dict)

        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)

        # recover random state
        if "rng" in extra_state_dict:
            self.load_rng_state(extra_state_dict["rng"])

    def save_checkpoint(self, path: str):
        path = self.local_mkdir(path)
        dist.barrier()

        # every rank will save its own model and optim shard
        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
                model_state_dict = self.model.state_dict()
                if self.optimizer is not None:
                    optimizer_state_dict = self.optimizer.state_dict()
                else:
                    optimizer_state_dict = None

                if self.lr_scheduler is not None:
                    lr_scheduler_state_dict = self.lr_scheduler.state_dict()
                else:
                    lr_scheduler_state_dict = None

                extra_state_dict = {
                    "lr_scheduler": lr_scheduler_state_dict,
                    "rng": self.get_rng_state(),
                }
                model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
                optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
                extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")

                print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
                print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
                print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
                torch.save(model_state_dict, model_path)
                if self.optimizer is not None:
                    torch.save(optimizer_state_dict, optim_path)

                torch.save(extra_state_dict, extra_path)

        # wait for every rank to finish writing its shards to local storage
        dist.barrier()

        if self.rank == 0:
            hf_path = os.path.join(path, "huggingface")
            os.makedirs(hf_path, exist_ok=True)
            assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel)
            self.model._fsdp_wrapped_module.config.save_pretrained(hf_path)
            self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_path)
            self.processing_class.save_pretrained(hf_path)

        dist.barrier()
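

# Usage sketch (illustrative only, not part of this module): the objects below
# (fsdp_model, optimizer, lr_scheduler, tokenizer, and the checkpoint path) are
# assumed to be constructed elsewhere by the training script under torch.distributed.
#
#     checkpoint_manager = FSDPCheckpointManager(
#         model=fsdp_model,            # an FSDP-wrapped PreTrainedModel
#         optimizer=optimizer,         # e.g. torch.optim.AdamW over the FSDP parameters
#         lr_scheduler=lr_scheduler,   # any torch.optim.lr_scheduler.LRScheduler
#         processing_class=tokenizer,  # PreTrainedTokenizer or ProcessorMixin
#     )
#
#     # Every rank writes model/optim/extra_state shards named by world size and rank;
#     # rank 0 additionally saves the huggingface config and tokenizer under `huggingface/`.
#     checkpoint_manager.save_checkpoint(path="checkpoints/global_step_100")
#
#     # Later, with the same world size, every rank reloads its own shards.
#     checkpoint_manager.load_checkpoint(path="checkpoints/global_step_100")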