"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a00c735d9f69555331c463c564a7a40d0bed73b5"
Commit 5d6bd7c2 authored by Anthony Chen, committed by Facebook GitHub Bot

Interleave FSDP checkpointing to avoid exceeding the Manifold quota

Summary:
X-link: https://github.com/facebookresearch/mobile-vision/pull/138

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/477

Interleave FSDP checkpointing across ranks so that checkpoint reads and writes take turns instead of hitting storage all at once, avoiding the burst I/O pattern that can trigger a Manifold quota-exceeded error.

Reviewed By: wat3rBro

Differential Revision: D43266742

fbshipit-source-id: 85549c3b10413e0ffad2f3ec8e198d8c77486478
parent 99709e93
...
@@ -5,10 +5,11 @@ from typing import cast, IO
 import detectron2.utils.comm as comm
 import torch
 from d2go.modeling.ema import EMAState
 from d2go.quantization.modeling import QATCheckpointer
 from d2go.trainer.fsdp import FSDPWrapper
+from mobile_cv.torch.utils_pytorch.distributed_helper import interleave_by_rank
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel as FSDP,
 )
@@ -145,9 +146,14 @@ class FSDPCheckpointer(QATCheckpointer):
         self.tag_last_checkpoint(basename)

     def _save_file(self, data, filename):
-        self.logger.info("Saving checkpoint to {}".format(filename))
-        with self.path_manager.open(filename, "wb") as f:
-            torch.save(data, cast(IO[bytes], f))
+        with interleave_by_rank():
+            self.logger.info("Saving checkpoint to {}".format(filename))
+            with self.path_manager.open(filename, "wb") as f:
+                torch.save(data, cast(IO[bytes], f))
+
+    def _load_file(self, f: str):
+        with interleave_by_rank():
+            return super()._load_file(f)


 def gather_optimizer_state_dict(optimizer, model: FSDPWrapper):
...
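The diff only shows the call sites; the `interleave_by_rank` helper itself lives in mobile_cv and its internals are not part of this change. Below is a minimal sketch of how such a context manager could work, assuming a simple round-robin scheme in which groups of ranks take turns executing the guarded I/O block while the other ranks wait at a barrier. The `group_size` parameter and the grouping strategy are illustrative assumptions; the real helper's signature and behavior may differ.

```python
# Hypothetical sketch of an interleave_by_rank-style context manager.
# Assumption: ranks are split into groups of `group_size`, and groups take
# turns running the guarded block so the storage backend never sees all
# ranks reading/writing simultaneously. The actual mobile_cv helper may
# use a different grouping or signature.
import contextlib

import torch.distributed as dist


@contextlib.contextmanager
def interleave_by_rank(group_size: int = 8):
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process run: nothing to interleave.
        yield
        return

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    num_rounds = (world_size + group_size - 1) // group_size

    for round_idx in range(num_rounds):
        if rank // group_size == round_idx:
            # This rank's turn: run the guarded I/O block.
            yield
        # Everyone waits until the current group finishes before the
        # next group starts.
        dist.barrier()
```

Under this shape, `_save_file` and `_load_file` above behave as before on a single process (the context manager is a no-op) and serialize automatically under torch.distributed, at the cost of one barrier wait per round per checkpoint.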