Commit 6aec097e authored by Kai Zhang, committed by Facebook GitHub Bot

Support evaluating the predictor

Summary:
Evaluate the predictor generated by the previous step.
This diff modifies `lightning_train_net` to reuse the existing evaluation logic by adding a `predictor_path` param.
It also makes the Lightning training backend depend on `cfg.MODEL.DEVICE`, so that in the evaluate_predictor step the user can switch the backend by changing the model device. This is useful for evaluating an int8 quantized model.
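
A minimal sketch of the new backend selection (the helper body is copied from this diff; the assert lines are only a hypothetical usage example, not part of the change):

    def get_accelerator(device: str) -> str:
        # CPU devices map to Lightning's "ddp_cpu" backend, anything else (e.g. "cuda") to "ddp"
        return "ddp_cpu" if device.lower() == "cpu" else "ddp"

    # e.g. evaluating an int8 quantized predictor on CPU selects the CPU backend:
    assert get_accelerator("cpu") == "ddp_cpu"
    assert get_accelerator("cuda") == "ddp"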

Reviewed By: newstzpz

Differential Revision: D27150609

fbshipit-source-id: fb72da3e81db932c0fa479350150720143e09a3e
parent 242b2d37
@@ -15,6 +15,15 @@ from d2go.utils.testing.helper import tempdir
 class TestLightningTrainNet(unittest.TestCase):
+    def setUp(self):
+        # set distributed backend to none to avoid spawning child process,
+        # which doesn't inherit the temporary dataset
+        patcher = unittest.mock.patch(
+            "d2go.tools.lightning_train_net.get_accelerator", return_value=None
+        )
+        self.addCleanup(patcher.stop)
+        patcher.start()
+
     def _get_cfg(self, tmp_dir) -> CfgNode:
         return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
@@ -24,14 +33,14 @@ class TestLightningTrainNet(unittest.TestCase):
         cfg = self._get_cfg(root_dir)
         # set distributed backend to none to avoid spawning child process,
         # which doesn't inherit the temporary dataset
-        main(cfg, accelerator=None)
+        main(cfg)

     @tempdir
     def test_checkpointing(self, tmp_dir):
         """ tests saving and loading from checkpoint. """
         cfg = self._get_cfg(tmp_dir)
-        out = main(cfg, accelerator=None)
+        out = main(cfg)

         ckpts = [file for file in os.listdir(tmp_dir) if file.endswith(".ckpt")]
         self.assertCountEqual(
             [
@@ -48,7 +57,7 @@ class TestLightningTrainNet(unittest.TestCase):
         # load the last checkpoint from previous training
         cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
-        out2 = main(cfg2, accelerator=None, eval_only=True)
+        out2 = main(cfg2, eval_only=True)

         accuracy = flatten_config_dict(out.accuracy)
         accuracy2 = flatten_config_dict(out2.accuracy)
         for k in accuracy:
@@ -80,6 +80,9 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
     return callbacks

+def get_accelerator(device: str) -> str:
+    return "ddp_cpu" if device.lower() == "cpu" else "ddp"
+
 def do_train(cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask) -> Dict[str, str]:
     """Runs the training loop with given trainer and task.
@@ -129,7 +132,6 @@ def main(
     num_machines: int = 1,
     num_gpus: int = 0,
     num_processes: int = 1,
-    accelerator: Optional[str] = "ddp",
 ) -> TrainOutput:
     """Main function for launching a training with lightning trainer
     Args:
@@ -139,8 +141,6 @@ def main(
         num_processes: Number of processes on each node.
             NOTE: Automatically set to the number of GPUs when using DDP.
             Set a value greater than 1 to mimic distributed training on CPUs.
-        accelerator: Backend for distributed training. Only DDP
-            and DPP_CPU are supported.
         eval_only: True if run evaluation only.
     """
     assert (
@@ -151,6 +151,7 @@ def main(
     task = task_cls.from_config(cfg, eval_only)
     tb_logger = TensorBoardLogger(save_dir=cfg.OUTPUT_DIR)
+
     trainer_params = {
         # training loop is bounded by max steps, use a large max_epochs to make
         # sure max_steps is met first
@@ -162,7 +163,7 @@ def main(
         "num_nodes": num_machines,
         "gpus": num_gpus,
         "num_processes": num_processes,
-        "accelerator": accelerator,
+        "accelerator": get_accelerator(cfg.MODEL.DEVICE),
         "callbacks": _get_trainer_callbacks(cfg),
         "logger": tb_logger,
         "num_sanity_val_steps": 0,
@@ -229,7 +230,6 @@ if __name__ == "__main__":
         num_machines=args.num_machines,
         num_gpus=args.num_gpus,
         num_processes=args.num_processes,
-        accelerator="ddp" if args.num_gpus > 0 else "ddp_cpu",
     )
     if get_rank() == 0:
         print(ret)