Commit 5d6bd7c2 authored by Anthony Chen's avatar Anthony Chen Committed by Facebook GitHub Bot
Browse files

interleave FSDP checkpointing to avoid manifold quota exceeding

Summary:
X-link: https://github.com/facebookresearch/mobile-vision/pull/138

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/477

Interleave FSDP checkpointing across ranks to avoid the bursty read/write pattern that may cause manifold quota-exceeded errors

Reviewed By: wat3rBro

Differential Revision: D43266742

fbshipit-source-id: 85549c3b10413e0ffad2f3ec8e198d8c77486478
parent 99709e93
......@@ -5,10 +5,11 @@ from typing import cast, IO
import detectron2.utils.comm as comm
import torch
from d2go.modeling.ema import EMAState
from d2go.quantization.modeling import QATCheckpointer
from d2go.trainer.fsdp import FSDPWrapper
from mobile_cv.torch.utils_pytorch.distributed_helper import interleave_by_rank
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
)
......@@ -145,9 +146,14 @@ class FSDPCheckpointer(QATCheckpointer):
self.tag_last_checkpoint(basename)
def _save_file(self, data, filename):
    """Serialize ``data`` to ``filename``, staggering writes across ranks.

    The write is wrapped in ``interleave_by_rank()`` so that distributed
    ranks take turns touching storage instead of all writing at once —
    the burst read/write pattern otherwise produced by FSDP sharded
    checkpointing can trip manifold quota-exceeded errors.

    Args:
        data: checkpoint payload passed to ``torch.save``.
        filename: destination path, opened via ``self.path_manager``.
    """
    # NOTE: the scraped diff duplicated this body (pre- and post-change
    # versions); this is the single intended post-change implementation.
    with interleave_by_rank():
        self.logger.info("Saving checkpoint to {}".format(filename))
        with self.path_manager.open(filename, "wb") as f:
            torch.save(data, cast(IO[bytes], f))
def _load_file(self, f: str):
    """Load a checkpoint file, staggering reads across ranks.

    Delegates to the parent class's loader, but wraps it in
    ``interleave_by_rank()`` so ranks read from storage one group at a
    time rather than simultaneously.
    """
    with interleave_by_rank():
        loaded = super()._load_file(f)
    return loaded
def gather_optimizer_state_dict(optimizer, model: FSDPWrapper):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment