Commit 2b5a3176 authored by Kai Zhang, committed by Facebook GitHub Bot

update Lightning module test for OSS

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/17

Use PyTorch Lightning checkpoint in the test.
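This migrates the tests from the internal stl.lightning checkpoint callback to the open-source pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint, whose keyword is dirpath= rather than the internal callback's directory=. As a rough sketch of the OSS API the diff moves to (the output directory and step budget here are placeholders, not part of this commit):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

    # Write .ckpt files into the given directory; save_last=True also
    # produces last.ckpt, which the resume test loads via MODEL.WEIGHTS.
    checkpoint_callback = ModelCheckpoint(dirpath="/tmp/output", save_last=True)
    trainer = pl.Trainer(max_steps=1, callbacks=[checkpoint_callback])
    # trainer.fit(task) would then leave /tmp/output/last.ckpt behind for
    # GeneralizedRCNNTask.load_from_checkpoint(...).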

Reviewed By: zhanghang1989

Differential Revision: D26962697

fbshipit-source-id: abe635e374c3ada130243f0eaadff34204f04fa1
parent 8407e5f2
@@ -41,10 +41,10 @@ class TestLightningTrainNet(unittest.TestCase):
             ckpts,
         )
-        with tempfile.TemporaryDirectory() as tmp_dir2:
+        tmp_dir2 = tempfile.TemporaryDirectory()  # noqa to avoid flaky test
         cfg2 = cfg.clone()
         cfg2.defrost()
-        cfg2.OUTPUT_DIR = tmp_dir2
+        cfg2.OUTPUT_DIR = tmp_dir2.name
         # load the last checkpoint from previous training
         cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")

@@ -53,3 +53,4 @@ class TestLightningTrainNet(unittest.TestCase):
         accuracy2 = flatten_config_dict(out2.accuracy)
         for k in accuracy:
             np.testing.assert_equal(accuracy[k], accuracy2[k])
+        tmp_dir2.cleanup()
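The first file's change replaces the `with tempfile.TemporaryDirectory()` context manager with an explicitly managed object: the directory must stay alive across the second training run and the accuracy assertions, and is removed only by the trailing tmp_dir2.cleanup(), which avoids flaky failures from the directory being torn down while checkpoints are still in use. A self-contained sketch of the pattern, using an illustrative file name:

    import os
    import tempfile

    tmp_dir2 = tempfile.TemporaryDirectory()  # no `with`: caller controls the lifetime
    path = os.path.join(tmp_dir2.name, "last.ckpt")  # .name holds the directory path
    with open(path, "w") as f:
        f.write("placeholder")
    # ... use files under tmp_dir2.name for as long as needed ...
    tmp_dir2.cleanup()  # remove the directory explicitly once done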
@@ -8,22 +8,19 @@ import unittest
 from copy import deepcopy
 from typing import Dict

-import d2go.runner.default_runner as default_runner
 import pytorch_lightning as pl  # type: ignore
 import torch
 from d2go.config import CfgNode
 from d2go.runner.lightning_task import GeneralizedRCNNTask
-from d2go.tests import meta_arch_helper as mah
 from detectron2.utils.events import EventStorage
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from torch import Tensor

+from d2go.tests import meta_arch_helper as mah
+
+OSSRUN = os.getenv('OSSRUN') == '1'
+

 class TestLightningTask(unittest.TestCase):
     def _get_cfg(self, tmp_dir: str) -> CfgNode:
-        runner = default_runner.Detectron2GoRunner()
-        cfg = mah.create_detection_cfg(runner, tmp_dir)
+        cfg = mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
         cfg.TEST.EVAL_PERIOD = cfg.SOLVER.MAX_ITER
         return cfg
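Note that _get_cfg now builds the config from the Lightning task class itself: mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir) replaces the old two-step construction through a Detectron2GoRunner instance, which is what lets the default_runner import go away. Assuming the helper keeps this two-argument form, a call reads:

    from d2go.runner.lightning_task import GeneralizedRCNNTask
    from d2go.tests import meta_arch_helper as mah

    cfg = mah.create_detection_cfg(GeneralizedRCNNTask, "/tmp/out")  # task class + output dir
    cfg.TEST.EVAL_PERIOD = cfg.SOLVER.MAX_ITER  # evaluate once, at the end of training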
@@ -38,14 +35,11 @@ class TestLightningTask(unittest.TestCase):
                 return False
         return True

+    @unittest.skipIf(OSSRUN, "not supported yet")
     def test_load_from_checkpoint(self) -> None:
         with tempfile.TemporaryDirectory() as tmp_dir:
             task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))

-            from stl.lightning.callbacks.model_checkpoint import ModelCheckpoint
-            checkpoint_callback = ModelCheckpoint(
-                directory=task.cfg.OUTPUT_DIR, has_user_data=False
-            )
+            checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR)
             params = {
                 "max_steps": 1,
                 "limit_train_batches": 1,

@@ -92,15 +86,13 @@ class TestLightningTask(unittest.TestCase):
             self._compare_state_dict(init_state, task.ema_state.state_dict())
         )

+    @unittest.skipIf(OSSRUN, "not supported yet")
     def test_load_ema_weights(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             cfg = self._get_cfg(tmp_dir)
             cfg.MODEL_EMA.ENABLED = True
             task = GeneralizedRCNNTask(cfg)

-            from stl.lightning.callbacks.model_checkpoint import ModelCheckpoint
             checkpoint_callback = ModelCheckpoint(
-                directory=task.cfg.OUTPUT_DIR, save_last=True
+                dirpath=task.cfg.OUTPUT_DIR, save_last=True
             )
             trainer = pl.Trainer(

@@ -115,9 +107,19 @@ class TestLightningTask(unittest.TestCase):
                 trainer.fit(task)

             # load EMA weights from checkpoint
-            task2 = GeneralizedRCNNTask.load_from_checkpoint(os.path.join(tmp_dir, "last.ckpt"))
-            self.assertTrue(self._compare_state_dict(task.ema_state.state_dict(), task2.ema_state.state_dict()))
+            task2 = GeneralizedRCNNTask.load_from_checkpoint(
+                os.path.join(tmp_dir, "last.ckpt")
+            )
+            self.assertTrue(
+                self._compare_state_dict(
+                    task.ema_state.state_dict(), task2.ema_state.state_dict()
+                )
+            )

             # apply EMA weights to model
             task2.ema_state.apply_to(task2.model)
-            self.assertTrue(self._compare_state_dict(task.ema_state.state_dict(), task2.model.state_dict()))
+            self.assertTrue(
+                self._compare_state_dict(
+                    task.ema_state.state_dict(), task2.model.state_dict()
+                )
+            )
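The OSSRUN flag added at import time lets the open-source CI (assumed here to export OSSRUN=1) skip the checkpoint-loading and EMA cases that are not supported there yet, using standard unittest.skipIf. A minimal, runnable sketch of the same gating pattern with a placeholder test body:

    import os
    import unittest

    OSSRUN = os.getenv('OSSRUN') == '1'  # truthy only when the OSS CI sets it

    class ExampleTest(unittest.TestCase):
        @unittest.skipIf(OSSRUN, "not supported yet")
        def test_internal_only_path(self):
            self.assertTrue(True)  # placeholder for logic that needs internal infra

    if __name__ == "__main__":
        unittest.main()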