#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import logging
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type

import mobile_cv.torch.utils_pytorch.comm as comm
import pytorch_lightning as pl  # type: ignore
from d2go.config import CfgNode
from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.setup import basic_argument_parser, setup_after_launch
from d2go.trainer.lightning.training_loop import _do_test, _do_train
from detectron2.utils.file_io import PathManager
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies.ddp import DDPStrategy
from torch.distributed import get_rank


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("detectron2go.lightning.train_net")

FINAL_MODEL_CKPT = f"model_final{ModelCheckpoint.FILE_EXTENSION}"


@dataclass
class TrainOutput:
    """Summary of a lightning train/eval run, as returned by ``main``.

    Only ``output_dir`` is required; every other field defaults to ``None``.
    """

    # Directory the run wrote its outputs to.
    output_dir: str
    # Evaluation results of the task (``task.eval_res`` in ``main``).
    accuracy: Optional[Dict[str, Any]] = None
    # Log directory of the TensorBoard logger handed to the trainer.
    tensorboard_log_dir: Optional[str] = None
    # Result of ``_do_train``; stays None for evaluation-only runs.
    model_configs: Optional[Dict[str, str]] = None


def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
    """Gets the trainer callbacks based on the given D2Go Config.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.

    Returns:
        A list of configured Callbacks to be used by the Lightning Trainer:
        a TQDM progress bar, a per-step learning-rate monitor, and a
        checkpoint callback writing into ``cfg.OUTPUT_DIR``. A
        quantization-aware-training callback is appended when
        ``cfg.QUANTIZATION.QAT.ENABLED`` is set.
    """
    callbacks: List[Callback] = [
        TQDMProgressBar(refresh_rate=10),  # Arbitrary refresh_rate.
        LearningRateMonitor(logging_interval="step"),
        ModelCheckpoint(
            dirpath=cfg.OUTPUT_DIR,
            # presumably produces the "last.ckpt" that `main` looks for when
            # deciding whether to resume — confirm against the PL version.
            save_last=True,
        ),
    ]
    if cfg.QUANTIZATION.QAT.ENABLED:
        callbacks.append(QuantizationAwareTraining.from_config(cfg))
    return callbacks

def _get_strategy(cfg: CfgNode) -> DDPStrategy:
    """Build the DDP strategy used by the Lightning trainer.

    ``find_unused_parameters`` is read from
    ``cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS``.
    """
    find_unused = cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS
    return DDPStrategy(find_unused_parameters=find_unused)


def _get_accelerator(use_cpu: bool) -> str:
67
    return "cpu" if use_cpu else "gpu"


def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
    """Build the keyword arguments for ``pl.Trainer`` from a D2Go config.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.

    Returns:
        A dict of ``pl.Trainer`` keyword arguments. Training length is set
        by ``cfg.SOLVER.MAX_ITER`` steps rather than by a number of epochs.
    """
    use_cpu = cfg.MODEL.DEVICE.lower() == "cpu"
    strategy = _get_strategy(cfg)
    accelerator = _get_accelerator(use_cpu)

    # Validate every TEST.EVAL_PERIOD training batches; a non-positive period
    # falls back to MAX_ITER, i.e. a single validation at the end.
    val_check_interval = (
        cfg.TEST.EVAL_PERIOD if cfg.TEST.EVAL_PERIOD > 0 else cfg.SOLVER.MAX_ITER
    )

    return {
        # No epoch limit; max_steps below bounds the run.
        "max_epochs": -1,
        "max_steps": cfg.SOLVER.MAX_ITER,
        "val_check_interval": val_check_interval,
        "num_nodes": comm.get_num_nodes(),
        "devices": comm.get_local_size(),
        "strategy": strategy,
        "accelerator": accelerator,
        "callbacks": _get_trainer_callbacks(cfg),
        "logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
        "num_sanity_val_steps": 0,
        # presumably the D2Go data loaders already shard data per rank, so
        # Lightning must not inject its own DistributedSampler — confirm.
        "replace_sampler_ddp": False,
    }


def main(
    cfg: CfgNode,
    output_dir: str,
    task_cls: Type[GeneralizedRCNNTask] = GeneralizedRCNNTask,
    eval_only: bool = False,
) -> TrainOutput:
    """Main function for launching a training with lightning trainer.

    Args:
        cfg: D2go config node.
        output_dir: Output directory, forwarded to ``setup_after_launch``.
        task_cls: Lightning task class used to build the task from ``cfg``.
        eval_only: True if run evaluation only.

    Returns:
        A ``TrainOutput`` holding ``cfg.OUTPUT_DIR``, the TensorBoard log
        directory, the task's evaluation results and, when training, the
        model configs returned by ``_do_train``.
    """
    setup_after_launch(cfg, output_dir, task_cls)

    task = task_cls.from_config(cfg, eval_only)
    trainer_params = get_trainer_params(cfg)

    # Resume from the most recent checkpoint in OUTPUT_DIR if one exists.
    last_checkpoint = os.path.join(cfg.OUTPUT_DIR, "last.ckpt")
    if PathManager.exists(last_checkpoint):
        trainer_params["resume_from_checkpoint"] = last_checkpoint
        logger.info("Resuming training from checkpoint: %s.", last_checkpoint)

    trainer = pl.Trainer(**trainer_params)

    model_configs = None
    if eval_only:
        _do_test(trainer, task)
    else:
        model_configs = _do_train(cfg, trainer, task)

    return TrainOutput(
        output_dir=cfg.OUTPUT_DIR,
        tensorboard_log_dir=trainer_params["logger"].log_dir,
        accuracy=task.eval_res,
        model_configs=model_configs,
    )


def build_config(
    config_file: str,
    task_cls: Type[GeneralizedRCNNTask],
    opts: Optional[List[str]] = None,
) -> CfgNode:
    """Build a config node from a config file.

    Args:
        config_file: Path to a D2go config file, merged on top of the task's
            default config.
        task_cls: Task class whose ``get_default_cfg()`` supplies the defaults.
        opts: A list of config overrides applied last, e.g.
            ["SOLVER.IMS_PER_BATCH", "2"]. ``None`` or empty means no overrides.

    Returns:
        The merged config node.
    """
    cfg = task_cls.get_default_cfg()
    cfg.merge_from_file(config_file)

    if opts:
        cfg.merge_from_list(opts)
    return cfg


def argument_parser():
    """Return the CLI parser: D2Go's distributed basic parser plus ``--num-gpus``.

    ``--output-dir`` is optional here (``requires_output_dir=False``).
    """
    cli = basic_argument_parser(distributed=True, requires_output_dir=False)
    cli.add_argument(
        "--num-gpus",
        default=0,
        type=int,
        help="number of GPUs per machine",
    )
    return cli


if __name__ == "__main__":
    args = argument_parser().parse_args()
    task_cls = create_runner(args.runner) if args.runner else GeneralizedRCNNTask
    cfg = build_config(args.config_file, task_cls, args.opts)
    ret = main(
        cfg,
        args.output_dir,
        task_cls,
        eval_only=False,  # eval_only
    )
    # torch.distributed.get_rank() raises when no process group is
    # initialized; mobile_cv's comm.get_rank() presumably returns 0 in that
    # case (as detectron2's does) — verify.
    if comm.get_rank() == 0:
        print(ret)