"tests/vscode:/vscode.git/clone" did not exist on "f64d52dbca93051a7652db7aa241964235a71035"
Commit 2228d3a0 authored by Mircea Cimpoi's avatar Mircea Cimpoi Committed by Facebook GitHub Bot
Browse files

EMA - add test for loading from checkpoint for eval-only

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/348

Add testcase to ensure loading from config in eval_only is covered.

Reviewed By: wat3rBro

Differential Revision: D38001319

fbshipit-source-id: e6a2edb5001ae87606a3bf48e1355037aee0f9a0
parent 3c811d21
......@@ -5,6 +5,7 @@
import os
import unittest
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Dict
import pytorch_lightning as pl # type: ignore
......@@ -91,7 +92,7 @@ class TestLightningTask(unittest.TestCase):
)
@tempdir
def test_load_ema_weights(self, tmp_dir):
def test_load_ema_weights(self, tmp_dir) -> None:
cfg = self._get_cfg(tmp_dir)
cfg.MODEL_EMA.ENABLED = True
task = GeneralizedRCNNTask(cfg)
......@@ -118,6 +119,40 @@ class TestLightningTask(unittest.TestCase):
)
)
@tempdir
def test_ema_eval_only_mode(self, tmp_dir: TemporaryDirectory) -> None:
"""Train one model for one iteration, then check if the
second task is loaded correctly from config and applied to model.x"""
cfg = self._get_cfg(tmp_dir)
cfg.MODEL.MODELING_HOOKS = ["EMA"]
cfg.MODEL_EMA.ENABLED = True
task = GeneralizedRCNNTask(cfg)
trainer = self._get_trainer(tmp_dir)
with EventStorage() as storage:
task.storage = storage
trainer.fit(task)
# load EMA weights from checkpoint
cfg2 = self._get_cfg(tmp_dir)
cfg2.MODEL.MODELING_HOOKS = ["EMA"]
cfg2.MODEL_EMA.ENABLED = True
cfg2.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
task2 = GeneralizedRCNNTask.from_config(cfg2)
self.assertTrue(task2.ema_state, "EMA state is not loaded from checkpoint.")
self.assertTrue(
len(task2.ema_state.state_dict()) > 0, "EMA state should not be empty."
)
self.assertTrue(
self._compare_state_dict(
task.ema_state.state_dict(), task2.model.state_dict()
),
"Task loaded from config should apply the ema_state to the model.",
)
def test_create_runner(self):
task_cls = create_runner(
f"{GeneralizedRCNNTask.__module__}.{GeneralizedRCNNTask.__name__}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment