# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

from .checkpoint_manager import BaseCheckpointManager


class FSDPCheckpointManager(BaseCheckpointManager):
    """
    A checkpoint manager that saves and loads
    - model
    - optimizer
    - lr_scheduler
    - extra_states
    in a SPMD way.

    We save
    - sharded model states and optimizer states
    - full lr_scheduler states
    - huggingface tokenizer and config for ckpt merge

    Every rank reads and writes its own three files, named after
    ``self.world_size`` and ``self.rank``, so a checkpoint can only be loaded
    with the same world size it was saved with.
    """

    def __init__(
        self,
        model: FSDP,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
    ):
        """Forward all arguments unchanged to ``BaseCheckpointManager.__init__``.

        Args:
            model: The FSDP-wrapped model whose state is checkpointed.
            optimizer: The optimizer whose state is checkpointed.
            lr_scheduler: The scheduler whose ``state_dict`` is stored in the
                extra state file.
            processing_class: Tokenizer or processor saved next to the
                checkpoint by rank 0.
        """
        # NOTE(review): self.model / self.optimizer / self.lr_scheduler /
        # self.processing_class / self.rank / self.world_size are presumably
        # set by the base class constructor - confirm in checkpoint_manager.
        super().__init__(model, optimizer, lr_scheduler, processing_class)

    def _shard_paths(self, path: str):
        """Return this rank's ``(model, optim, extra_state)`` file paths under ``path``.

        Shared by ``load_checkpoint`` and ``save_checkpoint`` so both sides
        always agree on the file naming scheme.
        """
        suffix = f"world_size_{self.world_size}_rank_{self.rank}.pt"
        return (
            os.path.join(path, f"model_{suffix}"),
            os.path.join(path, f"optim_{suffix}"),
            os.path.join(path, f"extra_state_{suffix}"),
        )

    def load_checkpoint(self, path: Optional[str] = None) -> None:
        """Restore model, optimizer, lr scheduler and RNG state from ``path``.

        Args:
            path: Directory containing this rank's checkpoint files. When
                ``None`` the call is a no-op.
        """
        if path is None:
            return

        # every rank download its own checkpoint
        model_path, optim_path, extra_state_path = self._shard_paths(path)
        print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")

        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints from trusted
        # sources.
        model_state_dict = torch.load(model_path, weights_only=False)
        optim_state_dict = torch.load(optim_path, weights_only=False)
        extra_state_dict = torch.load(extra_state_path, weights_only=False)

        state_dict_options = StateDictOptions(cpu_offload=True)
        set_state_dict(
            model=self.model,
            optimizers=self.optimizer,
            model_state_dict=model_state_dict,
            optim_state_dict=optim_state_dict,
            options=state_dict_options,
        )
        self.lr_scheduler.load_state_dict(extra_state_dict["lr_scheduler"])

        # recover random state (only if the checkpoint contains one)
        if "rng" in extra_state_dict:
            self.load_rng_state(extra_state_dict["rng"])

    def save_checkpoint(self, path: str) -> None:
        """Write this rank's checkpoint files to ``path``.

        Each rank saves its own model state, optimizer state and extra state
        (lr scheduler + RNG). After all ranks have finished, rank 0 also
        writes the model config, generation config and processing class into
        a ``huggingface`` sub-directory.

        Args:
            path: Target directory; passed through ``self.local_mkdir`` first.
        """
        # NOTE(review): local_mkdir presumably creates the directory and
        # returns the path to use - confirm in the base class.
        path = self.local_mkdir(path)
        dist.barrier()

        # every rank will save its own model and optim shard
        state_dict_options = StateDictOptions(cpu_offload=True)
        model_state_dict, optim_state_dict = get_state_dict(self.model, self.optimizer, options=state_dict_options)
        extra_state_dict = {
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "rng": self.get_rng_state(),
        }

        model_path, optim_path, extra_path = self._shard_paths(path)
        print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
        # Report the optimizer file here; the model path was previously
        # printed a second time by mistake.
        print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.")
        print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
        torch.save(model_state_dict, model_path)
        torch.save(optim_state_dict, optim_path)
        torch.save(extra_state_dict, extra_path)

        # wait for everyone to dump to local
        dist.barrier()

        if self.rank == 0:
            hf_path = os.path.join(path, "huggingface")
            os.makedirs(hf_path, exist_ok=True)
            wrapped_model = self.model._fsdp_wrapped_module
            assert isinstance(wrapped_model, PreTrainedModel)
            wrapped_model.config.save_pretrained(hf_path)
            wrapped_model.generation_config.save_pretrained(hf_path)
            self.processing_class.save_pretrained(hf_path)

        # keep the other ranks in step until rank 0 has finished writing
        dist.barrier()