Commit c688c175 authored by Anthony Chen, committed by Facebook GitHub Bot

add option to load checkpoints to GPU

Summary:
X-link: https://github.com/facebookresearch/detectron2/pull/4667

X-link: https://github.com/fairinternal/detectron2/pull/578

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/411

Add config option `cfg.LOAD_CKPT_TO_GPU` to load checkpoints to the worker's current GPU
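For illustration, a minimal sketch of turning the option on from Python. The flag defaults to `False` (see the config diff below); the config wiring follows the standard detectron2 `CfgNode` workflow, and the exact runner import path is an assumption:

    from d2go.runner import Detectron2GoRunner  # import path assumed

    runner = Detectron2GoRunner()
    cfg = runner.get_default_cfg()
    # New flag from this diff; defaults to False (checkpoints load to CPU).
    cfg.LOAD_CKPT_TO_GPU = True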

Previously, D2Go (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb) mapped checkpoints to CPU before loading them into the model. In large-scale distributed training, many GPU processes may be used to train a model, and each process loads its own copy of the checkpoint into CPU memory, so the same checkpoint is held many times on one host. This can cause a CPU OOM when the checkpoint is large.

There are two solutions to this problem: load checkpoints to GPU, or share the checkpoint between GPU processes via shared memory. This diff implements the first solution, which supports cases where model size plus checkpoint size fits within total GPU memory. The second solution may be revisited for large models that need to offload checkpoints to CPU. Reference diff: D40789062
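The core of the change is the `map_location` argument passed to `torch.load`: each worker maps checkpoint tensors straight to its own CUDA device instead of staging them in host RAM. A self-contained sketch of the helper added in the diff below (`_torch_load`), written here as a free function:

    import torch

    def _torch_load(f, load_ckpt_to_gpu=False):
        # Map the checkpoint to this worker's GPU when requested; otherwise
        # fall back to the previous behavior of loading onto the CPU.
        device = (
            "cuda:{}".format(torch.cuda.current_device())
            if load_ckpt_to_gpu
            else "cpu"
        )
        return torch.load(f, map_location=torch.device(device))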

Reviewed By: mcimpoi

Differential Revision: D41063306

fbshipit-source-id: edcfd390a25582fffb2f1a6a7fc22917874ee2fc
parent 61e5ddce
@@ -46,6 +46,23 @@ class QATCheckpointer(DetectionCheckpointer):
     (QAT / non-QAT) model.
     """
 
+    def __init__(
+        self,
+        model,
+        save_dir="",
+        *,
+        load_ckpt_to_gpu=False,
+        save_to_disk=None,
+        **checkpointables,
+    ):
+        super().__init__(
+            model,
+            save_dir,
+            save_to_disk=save_to_disk,
+            **checkpointables,
+        )
+        self.load_ckpt_to_gpu = load_ckpt_to_gpu
+
     @classmethod
     def _is_q_state_dict(cls, state_dict):
         return any(_is_observer_key(k) for k in state_dict)

@@ -56,7 +73,7 @@ class QATCheckpointer(DetectionCheckpointer):
         if filename.endswith(".ckpt"):
             # assume file is from lightning; no one else seems to use the ".ckpt" extension
             with PathManager.open(filename, "rb") as f:
-                data = torch.load(f, map_location=torch.device("cpu"))
+                data = self._torch_load(f)
             from d2go.runner.lightning_task import _convert_to_d2
 
             _convert_to_d2(data)

@@ -64,6 +81,14 @@ class QATCheckpointer(DetectionCheckpointer):
             return super()._load_file(filename)
 
+    def _torch_load(self, f):
+        device = (
+            "cuda:{}".format(torch.cuda.current_device())
+            if self.load_ckpt_to_gpu
+            else "cpu"
+        )
+        return torch.load(f, map_location=torch.device(device))
+
     def _load_model(self, checkpoint):
         model_is_qat = self._is_q_state_dict(self.model.state_dict())
         checkpoint_is_qat = self._is_q_state_dict(checkpoint["model"])

@@ -100,6 +100,9 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
     # Profiler
     _C.PROFILERS = ["default_flop_counter"]
 
+    # Checkpointing-specific config
+    _C.LOAD_CKPT_TO_GPU = False
+
     # Add FB specific configs
     _add_detectron2go_runner_default_fb_cfg(_C)

@@ -482,6 +482,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
             cfg,
             model,
             save_dir=cfg.OUTPUT_DIR,
+            load_ckpt_to_gpu=cfg.LOAD_CKPT_TO_GPU,
             optimizer=optimizer,
             scheduler=scheduler,
         )
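For completeness, a hedged usage sketch of passing the new keyword to the checkpointer directly, mirroring how `Detectron2GoRunner` forwards `cfg.LOAD_CKPT_TO_GPU` above; the model, paths, and import location are placeholders:

    import torch.nn as nn
    from d2go.checkpoint import QATCheckpointer  # import path is an assumption

    model = nn.Linear(8, 8).cuda()  # placeholder model
    checkpointer = QATCheckpointer(
        model,
        save_dir="/tmp/output",  # placeholder directory
        load_ckpt_to_gpu=True,   # new keyword from this diff
    )
    # Checkpoint tensors are mapped to this worker's current GPU on load.
    checkpointer.load("/tmp/output/model_final.pth")  # placeholder path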