Commit 409cd213, authored by Francisc Bungiu and committed by Facebook GitHub Bot
Browse files

Enable preemption checkpointing

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/639

Expose the ability to add a preemption checkpointing hook that runs in a separate process group.

Reviewed By: wat3rBro, ynonaolga

Differential Revision: D51115437

fbshipit-source-id: c843802bc59da9f57c09c8d9a20f3d72d5b98edf
parent d0e16684
...@@ -6,11 +6,13 @@ import contextlib ...@@ -6,11 +6,13 @@ import contextlib
import logging import logging
import os import os
from collections import OrderedDict from collections import OrderedDict
from datetime import timedelta
from functools import lru_cache from functools import lru_cache
from typing import Any, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
import torch import torch
import torch.distributed as dist
from d2go.checkpoint.api import is_distributed_checkpoint from d2go.checkpoint.api import is_distributed_checkpoint
from d2go.checkpoint.fsdp_checkpoint import FSDPCheckpointer from d2go.checkpoint.fsdp_checkpoint import FSDPCheckpointer
from d2go.config import CfgNode, CONFIG_SCALING_METHOD_REGISTRY, temp_defrost from d2go.config import CfgNode, CONFIG_SCALING_METHOD_REGISTRY, temp_defrost
...@@ -78,6 +80,7 @@ logger = logging.getLogger(__name__) ...@@ -78,6 +80,7 @@ logger = logging.getLogger(__name__)
ALL_TB_WRITERS = [] ALL_TB_WRITERS = []
CONTROL_PG_TIMEOUT = timedelta(minutes=30)
@lru_cache() @lru_cache()
...@@ -156,10 +159,28 @@ def get_monitoring_service() -> Any: ...@@ -156,10 +159,28 @@ def get_monitoring_service() -> Any:
return contextlib.nullcontext() return contextlib.nullcontext()
@fb_overwritable()
def create_preemption_hook(
    cfg: CfgNode,
    periodic_checkpointer: PeriodicCheckpointer,
    process_group: Optional[dist.ProcessGroup],
) -> Any:
    """Return the trainer hook used for preemption checkpointing, if any.

    This default implementation is a stub: it ignores every argument and
    returns ``None``, i.e. no preemption hook is provided.

    NOTE(review): presumably the ``fb_overwritable`` decorator allows an
    alternative implementation to replace this one -- confirm against the
    decorator's definition.

    Args:
        cfg: Runner configuration node. Unused here.
        periodic_checkpointer: The periodic checkpointer built by the runner.
            Unused here.
        process_group: The runner's control process group (created with the
            gloo backend), or ``None`` when ``torch.distributed`` has not
            been initialized. Unused here.

    Returns:
        Always ``None`` in this implementation. The caller places the result
        directly in its list of trainer hooks -- presumably ``None`` entries
        are dropped when the hooks are registered; TODO confirm.
    """
    return None
class BaseRunner(object): class BaseRunner(object):
def __init__(self): def __init__(self):
identifier = f"D2Go.Runner.{self.__class__.__name__}" identifier = f"D2Go.Runner.{self.__class__.__name__}"
torch._C._log_api_usage_once(identifier) torch._C._log_api_usage_once(identifier)
# initialize the control pg for stuff like checkpoint and preemption handling
logger.info("Initializing control pg")
self._control_pg: Optional[dist.ProcessGroup] = None
if dist.is_initialized():
logger.info("Create gloo CPU control pg")
self._control_pg = dist.new_group(
backend=dist.Backend.GLOO,
timeout=CONTROL_PG_TIMEOUT,
)
def _initialize(self, cfg): def _initialize(self, cfg):
"""Runner should be initialized in the sub-process in ddp setting""" """Runner should be initialized in the sub-process in ddp setting"""
...@@ -520,6 +541,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner): ...@@ -520,6 +541,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
self._create_after_step_hook( self._create_after_step_hook(
cfg, model, optimizer, scheduler, periodic_checkpointer cfg, model, optimizer, scheduler, periodic_checkpointer
), ),
create_preemption_hook(cfg, periodic_checkpointer, self._control_pg),
hooks.EvalHook( hooks.EvalHook(
cfg.TEST.EVAL_PERIOD, cfg.TEST.EVAL_PERIOD,
lambda: self.do_test(cfg, model, train_iter=trainer.iter), lambda: self.do_test(cfg, model, train_iter=trainer.iter),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment