"vscode:/vscode.git/clone" did not exist on "69f32b7fc91efd1d2c26f3b759e87509775ba4f0"
Commit 1a7f16bb authored by Kai Zhang's avatar Kai Zhang Committed by Facebook GitHub Bot
Browse files

disable logger in Lightning task test

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/46

As titled. The test is flaky because the tensorboard logger might still be writing to temporary folder when we tear down the folder.

Reviewed By: ananthsub

Differential Revision: D27844504

fbshipit-source-id: 3987f9ec3cd05b2f193e75cd4d85109a46f4ee71
parent 70f157c5
...@@ -31,6 +31,16 @@ class TestLightningTask(unittest.TestCase): ...@@ -31,6 +31,16 @@ class TestLightningTask(unittest.TestCase):
cfg.TEST.EVAL_PERIOD = cfg.SOLVER.MAX_ITER cfg.TEST.EVAL_PERIOD = cfg.SOLVER.MAX_ITER
return cfg return cfg
def _get_trainer(self, output_dir: str) -> pl.Trainer:
checkpoint_callback = ModelCheckpoint(dirpath=output_dir, save_last=True)
return pl.Trainer(
max_steps=1,
limit_train_batches=1,
num_sanity_val_steps=0,
callbacks=[checkpoint_callback],
logger=None,
)
def _compare_state_dict( def _compare_state_dict(
self, state1: Dict[str, Tensor], state2: Dict[str, Tensor] self, state1: Dict[str, Tensor], state2: Dict[str, Tensor]
) -> bool: ) -> bool:
...@@ -46,14 +56,7 @@ class TestLightningTask(unittest.TestCase): ...@@ -46,14 +56,7 @@ class TestLightningTask(unittest.TestCase):
def test_load_from_checkpoint(self, tmp_dir) -> None: def test_load_from_checkpoint(self, tmp_dir) -> None:
task = GeneralizedRCNNTask(self._get_cfg(tmp_dir)) task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR) trainer = self._get_trainer(tmp_dir)
params = {
"max_steps": 1,
"limit_train_batches": 1,
"num_sanity_val_steps": 0,
"callbacks": [checkpoint_callback],
}
trainer = pl.Trainer(**params)
with EventStorage() as storage: with EventStorage() as storage:
task.storage = storage task.storage = storage
trainer.fit(task) trainer.fit(task)
...@@ -77,11 +80,7 @@ class TestLightningTask(unittest.TestCase): ...@@ -77,11 +80,7 @@ class TestLightningTask(unittest.TestCase):
task = GeneralizedRCNNTask(cfg) task = GeneralizedRCNNTask(cfg)
init_state = deepcopy(task.model.state_dict()) init_state = deepcopy(task.model.state_dict())
trainer = pl.Trainer( trainer = self._get_trainer(tmp_dir)
max_steps=1,
limit_train_batches=1,
num_sanity_val_steps=0,
)
with EventStorage() as storage: with EventStorage() as storage:
task.storage = storage task.storage = storage
trainer.fit(task) trainer.fit(task)
...@@ -98,17 +97,7 @@ class TestLightningTask(unittest.TestCase): ...@@ -98,17 +97,7 @@ class TestLightningTask(unittest.TestCase):
cfg = self._get_cfg(tmp_dir) cfg = self._get_cfg(tmp_dir)
cfg.MODEL_EMA.ENABLED = True cfg.MODEL_EMA.ENABLED = True
task = GeneralizedRCNNTask(cfg) task = GeneralizedRCNNTask(cfg)
checkpoint_callback = ModelCheckpoint( trainer = self._get_trainer(tmp_dir)
dirpath=task.cfg.OUTPUT_DIR, save_last=True
)
trainer = pl.Trainer(
max_steps=1,
limit_train_batches=1,
num_sanity_val_steps=0,
callbacks=[checkpoint_callback],
)
with EventStorage() as storage: with EventStorage() as storage:
task.storage = storage task.storage = storage
trainer.fit(task) trainer.fit(task)
...@@ -142,16 +131,7 @@ class TestLightningTask(unittest.TestCase): ...@@ -142,16 +131,7 @@ class TestLightningTask(unittest.TestCase):
cfg = self._get_cfg(tmp_dir) cfg = self._get_cfg(tmp_dir)
cfg.MODEL_EMA.ENABLED = True cfg.MODEL_EMA.ENABLED = True
task = GeneralizedRCNNTask(cfg) task = GeneralizedRCNNTask(cfg)
checkpoint_callback = ModelCheckpoint( trainer = self._get_trainer(tmp_dir)
dirpath=task.cfg.OUTPUT_DIR, save_last=True
)
trainer = pl.Trainer(
max_steps=1,
limit_train_batches=1,
num_sanity_val_steps=0,
callbacks=[checkpoint_callback],
)
with EventStorage() as storage: with EventStorage() as storage:
task.storage = storage task.storage = storage
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment