Commit 6aec097e authored by Kai Zhang, committed by Facebook GitHub Bot

Support evaluate predictor

Summary:
Evaluate the predictor generated by the previous step.
This diff modifies lightning_train_net to reuse the evaluation logic by adding a `predictor_path` param.
This diff also makes the Lightning training backend depend on `cfg.MODEL.DEVICE`, so that in the evaluate_predictor step users can select the backend by changing the model device. This is useful for evaluating int8 quantized models.
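
As an illustration, a minimal sketch of the intended evaluation flow (the cfg construction mirrors the unit tests below; the import paths are assumptions, and the `predictor_path` wiring is described above but not shown here):

    import tempfile

    import d2go.utils.testing.meta_arch_helper as mah  # import path is an assumption
    from d2go.runner.lightning_task import GeneralizedRCNNTask  # import path is an assumption
    from d2go.tools.lightning_train_net import main

    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg = mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
        # The Lightning backend now follows the model device:
        # "cpu" -> "ddp_cpu", anything else -> "ddp". Setting the device
        # to "cpu" is how an int8 quantized predictor would be evaluated.
        cfg.MODEL.DEVICE = "cpu"
        out = main(cfg, eval_only=True)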

Reviewed By: newstzpz

Differential Revision: D27150609

fbshipit-source-id: fb72da3e81db932c0fa479350150720143e09a3e
parent 242b2d37
@@ -15,6 +15,15 @@ from d2go.utils.testing.helper import tempdir
 class TestLightningTrainNet(unittest.TestCase):
+    def setUp(self):
+        # set distributed backend to none to avoid spawning child process,
+        # which doesn't inherit the temporary dataset
+        patcher = unittest.mock.patch(
+            "d2go.tools.lightning_train_net.get_accelerator", return_value=None
+        )
+        self.addCleanup(patcher.stop)
+        patcher.start()
+
     def _get_cfg(self, tmp_dir) -> CfgNode:
         return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
@@ -24,14 +33,14 @@ class TestLightningTrainNet(unittest.TestCase):
         cfg = self._get_cfg(root_dir)
         # set distributed backend to none to avoid spawning child process,
         # which doesn't inherit the temporary dataset
-        main(cfg, accelerator=None)
+        main(cfg)

     @tempdir
     def test_checkpointing(self, tmp_dir):
         """ tests saving and loading from checkpoint. """
         cfg = self._get_cfg(tmp_dir)
-        out = main(cfg, accelerator=None)
+        out = main(cfg)

         ckpts = [file for file in os.listdir(tmp_dir) if file.endswith(".ckpt")]
         self.assertCountEqual(
             [
@@ -48,7 +57,7 @@
         # load the last checkpoint from previous training
         cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
-        out2 = main(cfg2, accelerator=None, eval_only=True)
+        out2 = main(cfg2, eval_only=True)

         accuracy = flatten_config_dict(out.accuracy)
         accuracy2 = flatten_config_dict(out2.accuracy)
         for k in accuracy:
@@ -80,6 +80,9 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
     return callbacks


+def get_accelerator(device: str) -> str:
+    return "ddp_cpu" if device.lower() == "cpu" else "ddp"
+

 def do_train(cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask) -> Dict[str, str]:
     """Runs the training loop with given trainer and task.
@@ -129,7 +132,6 @@ def main(
     num_machines: int = 1,
     num_gpus: int = 0,
     num_processes: int = 1,
-    accelerator: Optional[str] = "ddp",
 ) -> TrainOutput:
     """Main function for launching a training with lightning trainer
     Args:
@@ -139,8 +141,6 @@ def main(
         num_processes: Number of processes on each node.
            NOTE: Automatically set to the number of GPUs when using DDP.
            Set a value greater than 1 to mimic distributed training on CPUs.
-        accelerator: Backend for distributed training. Only DDP
-            and DPP_CPU are supported.
         eval_only: True if run evaluation only.
     """
     assert (
@@ -151,6 +151,7 @@ def main(
     task = task_cls.from_config(cfg, eval_only)
     tb_logger = TensorBoardLogger(save_dir=cfg.OUTPUT_DIR)
+
     trainer_params = {
         # training loop is bounded by max steps, use a large max_epochs to make
         # sure max_steps is met first
@@ -162,7 +163,7 @@ def main(
         "num_nodes": num_machines,
         "gpus": num_gpus,
         "num_processes": num_processes,
-        "accelerator": accelerator,
+        "accelerator": get_accelerator(cfg.MODEL.DEVICE),
         "callbacks": _get_trainer_callbacks(cfg),
         "logger": tb_logger,
         "num_sanity_val_steps": 0,
@@ -229,7 +230,6 @@ if __name__ == "__main__":
         num_machines=args.num_machines,
         num_gpus=args.num_gpus,
         num_processes=args.num_processes,
-        accelerator="ddp" if args.num_gpus > 0 else "ddp_cpu",
     )
     if get_rank() == 0:
         print(ret)
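
For reference, the device-to-backend mapping introduced above can be sanity-checked in isolation (a quick sketch, assuming d2go is importable; `get_accelerator` itself is part of this diff):

    from d2go.tools.lightning_train_net import get_accelerator

    assert get_accelerator("cpu") == "ddp_cpu"
    assert get_accelerator("CPU") == "ddp_cpu"  # the comparison is case-insensitive
    assert get_accelerator("cuda") == "ddp"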