Commit 1a7f16bb authored by Kai Zhang's avatar Kai Zhang Committed by Facebook GitHub Bot
Browse files

disable logger in Lightning task test

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/46

As titled. The test is flaky because the tensorboard logger might still be writing to temporary folder when we tear down the folder.

Reviewed By: ananthsub

Differential Revision: D27844504

fbshipit-source-id: 3987f9ec3cd05b2f193e75cd4d85109a46f4ee71
parent 70f157c5
......@@ -31,6 +31,16 @@ class TestLightningTask(unittest.TestCase):
cfg.TEST.EVAL_PERIOD = cfg.SOLVER.MAX_ITER
return cfg
def _get_trainer(self, output_dir: str) -> pl.Trainer:
checkpoint_callback = ModelCheckpoint(dirpath=output_dir, save_last=True)
return pl.Trainer(
max_steps=1,
limit_train_batches=1,
num_sanity_val_steps=0,
callbacks=[checkpoint_callback],
logger=None,
)
def _compare_state_dict(
self, state1: Dict[str, Tensor], state2: Dict[str, Tensor]
) -> bool:
......@@ -46,14 +56,7 @@ class TestLightningTask(unittest.TestCase):
def test_load_from_checkpoint(self, tmp_dir) -> None:
task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR)
params = {
"max_steps": 1,
"limit_train_batches": 1,
"num_sanity_val_steps": 0,
"callbacks": [checkpoint_callback],
}
trainer = pl.Trainer(**params)
trainer = self._get_trainer(tmp_dir)
with EventStorage() as storage:
task.storage = storage
trainer.fit(task)
......@@ -77,11 +80,7 @@ class TestLightningTask(unittest.TestCase):
task = GeneralizedRCNNTask(cfg)
init_state = deepcopy(task.model.state_dict())
trainer = pl.Trainer(
max_steps=1,
limit_train_batches=1,
num_sanity_val_steps=0,
)
trainer = self._get_trainer(tmp_dir)
with EventStorage() as storage:
task.storage = storage
trainer.fit(task)
......@@ -98,17 +97,7 @@ class TestLightningTask(unittest.TestCase):
cfg = self._get_cfg(tmp_dir)
cfg.MODEL_EMA.ENABLED = True
task = GeneralizedRCNNTask(cfg)
checkpoint_callback = ModelCheckpoint(
dirpath=task.cfg.OUTPUT_DIR, save_last=True
)
trainer = pl.Trainer(
max_steps=1,
limit_train_batches=1,
num_sanity_val_steps=0,
callbacks=[checkpoint_callback],
)
trainer = self._get_trainer(tmp_dir)
with EventStorage() as storage:
task.storage = storage
trainer.fit(task)
......@@ -142,16 +131,7 @@ class TestLightningTask(unittest.TestCase):
cfg = self._get_cfg(tmp_dir)
cfg.MODEL_EMA.ENABLED = True
task = GeneralizedRCNNTask(cfg)
checkpoint_callback = ModelCheckpoint(
dirpath=task.cfg.OUTPUT_DIR, save_last=True
)
trainer = pl.Trainer(
max_steps=1,
limit_train_batches=1,
num_sanity_val_steps=0,
callbacks=[checkpoint_callback],
)
trainer = self._get_trainer(tmp_dir)
with EventStorage() as storage:
task.storage = storage
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment