Commit d353b5af authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

use the same prepare_for_launch for lightning

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/319

follow up on D37500599 (https://github.com/facebookresearch/d2go/commit/668b7ac29b0afb55d5923e72fe4f6428e5c85cbd), move lightning_train_net part of D37367360 to this diff.

Reviewed By: sstsai-adl

Differential Revision: D37534370

fbshipit-source-id: 7f48942a14ce16a9a9540b189441b540ce4f4b25
parent 4208a791
...@@ -10,10 +10,9 @@ from typing import Any, Dict, List, Optional, Type, Union ...@@ -10,10 +10,9 @@ from typing import Any, Dict, List, Optional, Type, Union
import mobile_cv.torch.utils_pytorch.comm as comm import mobile_cv.torch.utils_pytorch.comm as comm
import pytorch_lightning as pl # type: ignore import pytorch_lightning as pl # type: ignore
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import QuantizationAwareTraining from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import DefaultTask, GeneralizedRCNNTask from d2go.runner.lightning_task import DefaultTask
from d2go.setup import basic_argument_parser, setup_after_launch from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
from d2go.trainer.lightning.training_loop import _do_test, _do_train from d2go.trainer.lightning.training_loop import _do_test, _do_train
from detectron2.utils.file_io import PathManager from detectron2.utils.file_io import PathManager
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, TQDMProgressBar from pytorch_lightning.callbacks import Callback, LearningRateMonitor, TQDMProgressBar
...@@ -128,30 +127,10 @@ def main( ...@@ -128,30 +127,10 @@ def main(
) )
def build_config(
config_file: str,
task_cls: Type[DefaultTask],
opts: Optional[List[str]] = None,
) -> CfgNode:
"""Build config node from config file
Args:
config_file: Path to a D2go config file
output_dir: When given, this will override the OUTPUT_DIR in the config
opts: A list of config overrides. e.g. ["SOLVER.IMS_PER_BATCH", "2"]
"""
cfg = task_cls.get_default_cfg()
cfg.merge_from_file(config_file)
if opts:
cfg.merge_from_list(opts)
return cfg
def argument_parser(): def argument_parser():
parser = basic_argument_parser(distributed=True, requires_output_dir=False) parser = basic_argument_parser(distributed=True, requires_output_dir=False)
parser.add_argument( # Change default runner argument
"--num-gpus", type=int, default=0, help="number of GPUs per machine" parser.set_defaults(runner="d2go.runner.lightning_task.GeneralizedRCNNTask")
)
parser.add_argument( parser.add_argument(
"--eval-only", action="store_true", help="perform evaluation only" "--eval-only", action="store_true", help="perform evaluation only"
) )
...@@ -160,16 +139,12 @@ def argument_parser(): ...@@ -160,16 +139,12 @@ def argument_parser():
if __name__ == "__main__": if __name__ == "__main__":
args = argument_parser().parse_args() args = argument_parser().parse_args()
task_cls = create_runner(args.runner) if args.runner else GeneralizedRCNNTask cfg, output_dir, runner_name = prepare_for_launch(args)
cfg = build_config(args.config_file, task_cls, args.opts)
assert args.output_dir or args.config_file
output_dir = args.output_dir or cfg.OUTPUT_DIR
ret = main( ret = main(
cfg, cfg,
output_dir, output_dir,
task_cls, runner_name,
eval_only=args.eval_only, eval_only=args.eval_only,
) )
if get_rank() == 0: if get_rank() == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment