Commit e2537c82 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

separate TestNetOutput and TrainNetOutput

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/449

separate TestNetOutput and TrainNetOutput
- update d2go binaries
- update operators / workflows

Reviewed By: mcimpoi

Differential Revision: D42103714

fbshipit-source-id: 53f318c79d7339fb6fcfc3486e8b9cf249a598bf
parent ab49d0b6
......@@ -11,14 +11,20 @@ from typing import Dict, Optional
from d2go.evaluation.api import AccuracyDict, MetricsDict
@dataclass
class TrainNetOutput:
    """Structured output of a training run (`main` with eval_only=False).

    Produced by the d2go train_net binaries; the eval-only path returns a
    ``TestNetOutput`` instead, which carries the same fields minus
    ``model_configs``.
    """

    # Per-dataset accuracy numbers collected at the end of training.
    accuracy: AccuracyDict[float]
    # Full evaluation metrics; in the current binaries this is the same
    # payload as `accuracy` (both are filled from the eval results).
    metrics: MetricsDict[float]
    # Mapping of model-config name -> dumped config produced by training
    # (e.g. from dump_trained_model_configs). Always a dict now that the
    # train/test outputs are separate types (T127368935 resolved).
    model_configs: Dict[str, str]
    # TODO (T127368603): decide if `tensorboard_log_dir` should be part of output
    tensorboard_log_dir: Optional[str] = None
@dataclass
class TestNetOutput:
    """Structured output of an evaluation-only run (`main` with eval_only=True).

    Counterpart of ``TrainNetOutput`` without ``model_configs``, since an
    eval-only run does not produce trained model configs.
    """

    # Per-dataset accuracy numbers from evaluation.
    accuracy: AccuracyDict[float]
    # Full evaluation metrics; callers currently fill this with the same
    # payload as `accuracy`.
    metrics: MetricsDict[float]
    # TODO (T127368603): decide if `tensorboard_log_dir` should be part of output
    tensorboard_log_dir: Optional[str] = None
......
......@@ -12,7 +12,7 @@ from d2go.config import CfgNode
from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import DefaultTask
from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
from d2go.trainer.api import TrainNetOutput
from d2go.trainer.api import TestNetOutput, TrainNetOutput
from d2go.trainer.helper import parse_precision_from_string
from d2go.trainer.lightning.training_loop import _do_test, _do_train
from detectron2.utils.file_io import PathManager
......@@ -103,7 +103,7 @@ def main(
output_dir: str,
runner_class: Union[str, Type[DefaultTask]],
eval_only: bool = False,
) -> TrainNetOutput:
) -> Union[TrainNetOutput, TestNetOutput]:
"""Main function for launching a training with lightning trainer
Args:
cfg: D2go config node
......@@ -123,18 +123,22 @@ def main(
logger.info(f"Resuming training from checkpoint: {last_checkpoint}.")
trainer = pl.Trainer(**trainer_params)
model_configs = None
if eval_only:
_do_test(trainer, task)
return TestNetOutput(
tensorboard_log_dir=trainer_params["logger"].log_dir,
accuracy=task.eval_res,
metrics=task.eval_res,
)
else:
model_configs = _do_train(cfg, trainer, task)
return TrainNetOutput(
tensorboard_log_dir=trainer_params["logger"].log_dir,
accuracy=task.eval_res,
metrics=task.eval_res,
model_configs=model_configs,
)
return TrainNetOutput(
tensorboard_log_dir=trainer_params["logger"].log_dir,
accuracy=task.eval_res,
metrics=task.eval_res,
model_configs=model_configs,
)
def argument_parser():
......
......@@ -9,7 +9,6 @@ import logging
import sys
from typing import List, Type, Union
import detectron2.utils.comm as comm
from d2go.config import CfgNode
from d2go.distributed import launch
from d2go.runner import BaseRunner
......@@ -22,7 +21,7 @@ from d2go.setup import (
setup_before_launch,
setup_root_logger,
)
from d2go.trainer.api import TrainNetOutput
from d2go.trainer.api import TestNetOutput, TrainNetOutput
from d2go.trainer.fsdp import create_ddp_model_with_sharding
from d2go.utils.misc import (
dump_trained_model_configs,
......@@ -40,7 +39,7 @@ def main(
runner_class: Union[str, Type[BaseRunner]],
eval_only: bool = False,
resume: bool = True, # NOTE: always enable resume when running on cluster
) -> TrainNetOutput:
) -> Union[TrainNetOutput, TestNetOutput]:
runner = setup_after_launch(cfg, output_dir, runner_class)
model = runner.build_model(cfg)
......@@ -58,9 +57,8 @@ def main(
model.eval()
metrics = runner.do_test(cfg, model, train_iter=train_iter)
print_metrics_table(metrics)
return TrainNetOutput(
return TestNetOutput(
accuracy=metrics,
model_configs={},
metrics=metrics,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment