Commit fc5616c8 authored by Ananth Subramaniam, committed by Facebook GitHub Bot

Make checkpointing tests slightly less restrictive

Summary:
Before: this test assumed that exactly 2 checkpoints were stored: `last.ckpt` and `FINAL_MODEL_CKPT`.
Now: this test asserts that at least these 2 checkpoints are stored. If the config specifies `save_top_k=-1`, for instance, more checkpoints are saved, which would cause the old assertion to fail.

Since this test only loads the last and the final outputs, I'm changing the behavior to assert that these two checkpoints must be saved, ignoring any other checkpoint files that could be generated.
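
Roughly, the difference boils down to the following minimal sketch (the file names besides `last.ckpt` are hypothetical: `model_final.ckpt` stands in for `FINAL_MODEL_CKPT`, and the extra per-epoch checkpoint is just an example of what `save_top_k=-1` might produce):

    import unittest


    class CheckpointAssertionSketch(unittest.TestCase):
        def test_at_least_expected_ckpts(self):
            # Hypothetical directory listing: with save_top_k=-1 the trainer may
            # write extra per-epoch checkpoints alongside the two files we load.
            ckpts = ["last.ckpt", "model_final.ckpt", "epoch=0-step=99.ckpt"]
            expected_ckpts = ("last.ckpt", "model_final.ckpt")

            # Old assertion: exact multiset equality, which breaks as soon as
            # any extra checkpoint file shows up.
            with self.assertRaises(AssertionError):
                self.assertCountEqual(list(expected_ckpts), ckpts)

            # New assertion: only require that the expected checkpoints exist.
            for ckpt in expected_ckpts:
                self.assertIn(ckpt, ckpts)


    if __name__ == "__main__":
        unittest.main()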

Reviewed By: kazhang

Differential Revision: D27671284

fbshipit-source-id: 0419fb46856d048e7b6eba3ff1dc65b7280a9a90
parent e47d6a24
@@ -40,18 +40,14 @@ class TestLightningTrainNet(unittest.TestCase):
         cfg = self._get_cfg(tmp_dir)
         out = main(cfg)
-        ckpts = [file for file in os.listdir(tmp_dir) if file.endswith(".ckpt")]
-        self.assertCountEqual(
-            [
-                "last.ckpt",
-                FINAL_MODEL_CKPT,
-            ],
-            ckpts,
-        )
+        ckpts = [f for f in os.listdir(tmp_dir) if f.endswith(".ckpt")]
+        expected_ckpts = ("last.ckpt", FINAL_MODEL_CKPT)
+        for ckpt in expected_ckpts:
+            self.assertIn(ckpt, ckpts)
         cfg2 = cfg.clone()
         cfg2.defrost()
-        cfg2.OUTPUT_DIR = os.path.join(tmp_dir, 'output')
+        cfg2.OUTPUT_DIR = os.path.join(tmp_dir, "output")
         # load the last checkpoint from previous training
         cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
@@ -54,7 +54,7 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
         cfg: The normalized ConfigNode for this D2Go Task.
     Returns:
-        A list of configured Callbacks to be used by the Lightning Traininer.
+        A list of configured Callbacks to be used by the Lightning Trainer.
     """
     callbacks: List[Callback] = [
         LearningRateMonitor(logging_interval="step"),
@@ -84,7 +84,9 @@ def get_accelerator(device: str) -> str:
     return "ddp_cpu" if device.lower() == "cpu" else "ddp"
-def do_train(cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask) -> Dict[str, str]:
+def do_train(
+    cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask
+) -> Dict[str, str]:
     """Runs the training loop with given trainer and task.
     Args: