Commit e2537c82, authored by Yanghan Wang and committed by Facebook GitHub Bot
Browse files

separate TestNetOutput and TrainNetOutput

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/449

separate TestNetOutput and TrainNetOutput
- update d2go binaries
- update operators / workflows

Reviewed By: mcimpoi

Differential Revision: D42103714

fbshipit-source-id: 53f318c79d7339fb6fcfc3486e8b9cf249a598bf
parent ab49d0b6
...@@ -11,14 +11,20 @@ from typing import Dict, Optional ...@@ -11,14 +11,20 @@ from typing import Dict, Optional
from d2go.evaluation.api import AccuracyDict, MetricsDict from d2go.evaluation.api import AccuracyDict, MetricsDict
# TODO (T127368935) Split to TrainNetOutput and TestNetOutput
@dataclass @dataclass
class TrainNetOutput: class TrainNetOutput:
accuracy: AccuracyDict[float] accuracy: AccuracyDict[float]
metrics: MetricsDict[float] metrics: MetricsDict[float]
# Optional, because we use None to distinguish "not used" from model_configs: Dict[str, str]
# empty model configs. With T127368935, this should be reverted to dict. # TODO (T127368603): decide if `tensorboard_log_dir` should be part of output
model_configs: Optional[Dict[str, str]] tensorboard_log_dir: Optional[str] = None
@dataclass
class TestNetOutput:
accuracy: AccuracyDict[float]
metrics: MetricsDict[float]
# TODO (T127368603): decide if `tensorboard_log_dir` should be part of output # TODO (T127368603): decide if `tensorboard_log_dir` should be part of output
tensorboard_log_dir: Optional[str] = None tensorboard_log_dir: Optional[str] = None
......
...@@ -12,7 +12,7 @@ from d2go.config import CfgNode ...@@ -12,7 +12,7 @@ from d2go.config import CfgNode
from d2go.runner.callbacks.quantization import QuantizationAwareTraining from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import DefaultTask from d2go.runner.lightning_task import DefaultTask
from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
from d2go.trainer.api import TrainNetOutput from d2go.trainer.api import TestNetOutput, TrainNetOutput
from d2go.trainer.helper import parse_precision_from_string from d2go.trainer.helper import parse_precision_from_string
from d2go.trainer.lightning.training_loop import _do_test, _do_train from d2go.trainer.lightning.training_loop import _do_test, _do_train
from detectron2.utils.file_io import PathManager from detectron2.utils.file_io import PathManager
...@@ -103,7 +103,7 @@ def main( ...@@ -103,7 +103,7 @@ def main(
output_dir: str, output_dir: str,
runner_class: Union[str, Type[DefaultTask]], runner_class: Union[str, Type[DefaultTask]],
eval_only: bool = False, eval_only: bool = False,
) -> TrainNetOutput: ) -> Union[TrainNetOutput, TestNetOutput]:
"""Main function for launching a training with lightning trainer """Main function for launching a training with lightning trainer
Args: Args:
cfg: D2go config node cfg: D2go config node
...@@ -123,18 +123,22 @@ def main( ...@@ -123,18 +123,22 @@ def main(
logger.info(f"Resuming training from checkpoint: {last_checkpoint}.") logger.info(f"Resuming training from checkpoint: {last_checkpoint}.")
trainer = pl.Trainer(**trainer_params) trainer = pl.Trainer(**trainer_params)
model_configs = None
if eval_only: if eval_only:
_do_test(trainer, task) _do_test(trainer, task)
return TestNetOutput(
tensorboard_log_dir=trainer_params["logger"].log_dir,
accuracy=task.eval_res,
metrics=task.eval_res,
)
else: else:
model_configs = _do_train(cfg, trainer, task) model_configs = _do_train(cfg, trainer, task)
return TrainNetOutput(
return TrainNetOutput( tensorboard_log_dir=trainer_params["logger"].log_dir,
tensorboard_log_dir=trainer_params["logger"].log_dir, accuracy=task.eval_res,
accuracy=task.eval_res, metrics=task.eval_res,
metrics=task.eval_res, model_configs=model_configs,
model_configs=model_configs, )
)
def argument_parser(): def argument_parser():
......
...@@ -9,7 +9,6 @@ import logging ...@@ -9,7 +9,6 @@ import logging
import sys import sys
from typing import List, Type, Union from typing import List, Type, Union
import detectron2.utils.comm as comm
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.distributed import launch from d2go.distributed import launch
from d2go.runner import BaseRunner from d2go.runner import BaseRunner
...@@ -22,7 +21,7 @@ from d2go.setup import ( ...@@ -22,7 +21,7 @@ from d2go.setup import (
setup_before_launch, setup_before_launch,
setup_root_logger, setup_root_logger,
) )
from d2go.trainer.api import TrainNetOutput from d2go.trainer.api import TestNetOutput, TrainNetOutput
from d2go.trainer.fsdp import create_ddp_model_with_sharding from d2go.trainer.fsdp import create_ddp_model_with_sharding
from d2go.utils.misc import ( from d2go.utils.misc import (
dump_trained_model_configs, dump_trained_model_configs,
...@@ -40,7 +39,7 @@ def main( ...@@ -40,7 +39,7 @@ def main(
runner_class: Union[str, Type[BaseRunner]], runner_class: Union[str, Type[BaseRunner]],
eval_only: bool = False, eval_only: bool = False,
resume: bool = True, # NOTE: always enable resume when running on cluster resume: bool = True, # NOTE: always enable resume when running on cluster
) -> TrainNetOutput: ) -> Union[TrainNetOutput, TestNetOutput]:
runner = setup_after_launch(cfg, output_dir, runner_class) runner = setup_after_launch(cfg, output_dir, runner_class)
model = runner.build_model(cfg) model = runner.build_model(cfg)
...@@ -58,9 +57,8 @@ def main( ...@@ -58,9 +57,8 @@ def main(
model.eval() model.eval()
metrics = runner.do_test(cfg, model, train_iter=train_iter) metrics = runner.do_test(cfg, model, train_iter=train_iter)
print_metrics_table(metrics) print_metrics_table(metrics)
return TrainNetOutput( return TestNetOutput(
accuracy=metrics, accuracy=metrics,
model_configs={},
metrics=metrics, metrics=metrics,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment