Commit dba54f21 authored by Mik Vyatskov's avatar Mik Vyatskov Committed by Facebook GitHub Bot
Browse files

Move TrainNetOutput from the binary to the library

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/357

This change makes it possible to unpickle TrainNetOutput which is currently cannot be unpickled because it's a part of main module which can be different for the binary that's unpickling this dataclass.

Reviewed By: miqueljubert

Differential Revision: D38536040

fbshipit-source-id: 856594251b2eca7630d69c7917bc4746859dab9f
parent 9a7b2e0f
...@@ -6,7 +6,21 @@ ...@@ -6,7 +6,21 @@
Trainer APIs on which D2Go's binary can build on top. Trainer APIs on which D2Go's binary can build on top.
""" """
# TODO: placeholder for now from dataclasses import dataclass
from typing import Dict, Optional
from d2go.evaluation.api import AccuracyDict, MetricsDict
# TODO (T127368935) Split to TrainNetOutput and TestNetOutput
@dataclass
class TrainNetOutput:
accuracy: AccuracyDict[float]
metrics: MetricsDict[float]
# Optional, because we use None to distinguish "not used" from
# empty model configs. With T127368935, this should be reverted to dict.
model_configs: Optional[Dict[str, str]]
# TODO (T127368603): decide if `tensorboard_log_dir` should be part of output
tensorboard_log_dir: Optional[str] = None
def do_train(): def do_train():
......
...@@ -12,7 +12,7 @@ from d2go.config import CfgNode ...@@ -12,7 +12,7 @@ from d2go.config import CfgNode
from d2go.runner.callbacks.quantization import QuantizationAwareTraining from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import DefaultTask from d2go.runner.lightning_task import DefaultTask
from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
from d2go.tools.train_net import TrainNetOutput from d2go.trainer.api import TrainNetOutput
from d2go.trainer.lightning.training_loop import _do_test, _do_train from d2go.trainer.lightning.training_loop import _do_test, _do_train
from detectron2.utils.file_io import PathManager from detectron2.utils.file_io import PathManager
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, TQDMProgressBar from pytorch_lightning.callbacks import Callback, LearningRateMonitor, TQDMProgressBar
......
...@@ -7,13 +7,11 @@ Detection Training Script. ...@@ -7,13 +7,11 @@ Detection Training Script.
import logging import logging
import sys import sys
from dataclasses import dataclass from typing import List, Type, Union
from typing import Dict, List, Optional, Type, Union
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.distributed import launch from d2go.distributed import launch
from d2go.evaluation.api import AccuracyDict, MetricsDict
from d2go.runner import BaseRunner from d2go.runner import BaseRunner
from d2go.setup import ( from d2go.setup import (
basic_argument_parser, basic_argument_parser,
...@@ -22,6 +20,7 @@ from d2go.setup import ( ...@@ -22,6 +20,7 @@ from d2go.setup import (
prepare_for_launch, prepare_for_launch,
setup_after_launch, setup_after_launch,
) )
from d2go.trainer.api import TrainNetOutput
from d2go.utils.misc import ( from d2go.utils.misc import (
dump_trained_model_configs, dump_trained_model_configs,
print_metrics_table, print_metrics_table,
...@@ -32,17 +31,6 @@ from detectron2.engine.defaults import create_ddp_model ...@@ -32,17 +31,6 @@ from detectron2.engine.defaults import create_ddp_model
logger = logging.getLogger("d2go.tools.train_net") logger = logging.getLogger("d2go.tools.train_net")
# TODO (T127368935) Split to TrainNetOutput and TestNetOutput
@dataclass
class TrainNetOutput:
accuracy: AccuracyDict[float]
metrics: MetricsDict[float]
# Optional, because we use None to distinguish "not used" from
# empty model configs. With T127368935, this should be reverted to dict.
model_configs: Optional[Dict[str, str]]
# TODO (T127368603): decide if `tensorboard_log_dir` should be part of output
tensorboard_log_dir: Optional[str] = None
def main( def main(
cfg: CfgNode, cfg: CfgNode,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment