Commit 8eab506b authored by Francisc Bungiu's avatar Francisc Bungiu Committed by Facebook GitHub Bot
Browse files

Add preemption checkpointing to lightning tasks

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/666

While debugging elevated preemption wastage in d2go, came across a few long running Pinocchio jobs in d2go that do not checkpoint preemption and also do not have checkpointing instrumented. This diff addresses both of these issues.

Reviewed By: wat3rBro

Differential Revision: D58669254

fbshipit-source-id: 9d1c5ff9e61a4a83d284a45154aa54d2d41178cf
parent 20054748
......@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple
import pytorch_lightning as pl
import torch
from d2go.checkpoint.checkpoint_instrumentation import instrument_checkpoint
from d2go.config import CfgNode
from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
from d2go.data.utils import update_cfg_if_using_adhoc_dataset
......@@ -414,10 +415,12 @@ class DefaultTask(D2GoDataAPIMixIn, pl.LightningModule):
if self.ema_state and hasattr(self, "model_ema"):
del self.model_ema
@instrument_checkpoint("save")
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if self.ema_state:
checkpoint["model_ema"] = self.ema_state.state_dict()
@instrument_checkpoint("load")
def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
"""
Called before model state is restored. Explicitly handles old model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment