Commit 409cd213 authored by Francisc Bungiu's avatar Francisc Bungiu Committed by Facebook GitHub Bot
Browse files

Enable preemption checkpointing

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/639

Expose ability to add a preemption checkpointing hook running in a separate process group.

Reviewed By: wat3rBro, ynonaolga

Differential Revision: D51115437

fbshipit-source-id: c843802bc59da9f57c09c8d9a20f3d72d5b98edf
parent d0e16684
......@@ -6,11 +6,13 @@ import contextlib
import logging
import os
from collections import OrderedDict
from datetime import timedelta
from functools import lru_cache
from typing import Any, List, Optional, Type, Union
import detectron2.utils.comm as comm
import torch
import torch.distributed as dist
from d2go.checkpoint.api import is_distributed_checkpoint
from d2go.checkpoint.fsdp_checkpoint import FSDPCheckpointer
from d2go.config import CfgNode, CONFIG_SCALING_METHOD_REGISTRY, temp_defrost
......@@ -78,6 +80,7 @@ logger = logging.getLogger(__name__)
ALL_TB_WRITERS = []
CONTROL_PG_TIMEOUT = timedelta(minutes=30)
@lru_cache()
......@@ -156,10 +159,28 @@ def get_monitoring_service() -> Any:
return contextlib.nullcontext()
@fb_overwritable()
def create_preemption_hook(
cfg: CfgNode,
periodic_checkpointer: PeriodicCheckpointer,
process_group: Optional[dist.ProcessGroup],
) -> Any:
return None
class BaseRunner(object):
def __init__(self):
identifier = f"D2Go.Runner.{self.__class__.__name__}"
torch._C._log_api_usage_once(identifier)
# initialize the control pg for stuff like checkpoint and preemption handling
logger.info("Initializing control pg")
self._control_pg: Optional[dist.ProcessGroup] = None
if dist.is_initialized():
logger.info("Create gloo CPU control pg")
self._control_pg = dist.new_group(
backend=dist.Backend.GLOO,
timeout=CONTROL_PG_TIMEOUT,
)
def _initialize(self, cfg):
"""Runner should be initialized in the sub-process in ddp setting"""
......@@ -520,6 +541,7 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
self._create_after_step_hook(
cfg, model, optimizer, scheduler, periodic_checkpointer
),
create_preemption_hook(cfg, periodic_checkpointer, self._control_pg),
hooks.EvalHook(
cfg.TEST.EVAL_PERIOD,
lambda: self.do_test(cfg, model, train_iter=trainer.iter),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment