lightning_train_net.py 6.2 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import logging
import os
7
from typing import Any, Dict, List, Type, Union
facebook-github-bot's avatar
facebook-github-bot committed
8

9
import mobile_cv.torch.utils_pytorch.comm as comm
facebook-github-bot's avatar
facebook-github-bot committed
10
import pytorch_lightning as pl  # type: ignore
11
from d2go.config import CfgNode
12
from d2go.runner.callbacks.quantization import QuantizationAwareTraining
13
14
from d2go.runner.lightning_task import DefaultTask
from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
15
from d2go.trainer.api import TestNetOutput, TrainNetOutput
16
from d2go.trainer.helper import parse_precision_from_string
17
from d2go.trainer.lightning.training_loop import _do_test, _do_train
18
from detectron2.utils.file_io import PathManager
19
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, TQDMProgressBar
facebook-github-bot's avatar
facebook-github-bot committed
20
21
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
22
from pytorch_lightning.strategies.ddp import DDPStrategy
facebook-github-bot's avatar
facebook-github-bot committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from torch.distributed import get_rank


# Configure root logging at INFO when this module is imported, so the
# module logger below emits info-level messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("detectron2go.lightning.train_net")

# Checkpoint file name "model_final" + Lightning's checkpoint file extension.
FINAL_MODEL_CKPT = "model_final" + ModelCheckpoint.FILE_EXTENSION


def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
    """Build the Lightning callbacks for a D2Go config.

    Always returns a TQDM progress bar, a per-step learning-rate monitor and a
    ``ModelCheckpoint`` writing into ``cfg.OUTPUT_DIR`` with ``save_last=True``.
    When ``cfg.QUANTIZATION.QAT.ENABLED`` is set, a ``QuantizationAwareTraining``
    callback built from ``cfg`` is appended after those.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.

    Returns:
        The callbacks to pass to the Lightning Trainer, in the order above.
    """
    progress_bar = TQDMProgressBar(refresh_rate=10)  # refresh_rate is arbitrary.
    lr_monitor = LearningRateMonitor(logging_interval="step")
    checkpointer = ModelCheckpoint(dirpath=cfg.OUTPUT_DIR, save_last=True)
    callbacks: List[Callback] = [progress_bar, lr_monitor, checkpointer]

    qat_enabled = cfg.QUANTIZATION.QAT.ENABLED
    if qat_enabled:
        callbacks.append(QuantizationAwareTraining.from_config(cfg))
    return callbacks

Yanghan Wang's avatar
Yanghan Wang committed
53

54
55
56
57
def _get_strategy(cfg: CfgNode) -> DDPStrategy:
    """Create a DDP strategy configured from ``cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS``."""
    find_unused = cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS
    return DDPStrategy(find_unused_parameters=find_unused)


Kai Zhang's avatar
Kai Zhang committed
58
def _get_accelerator(use_cpu: bool) -> str:
59
    return "cpu" if use_cpu else "gpu"
facebook-github-bot's avatar
facebook-github-bot committed
60

Kai Zhang's avatar
Kai Zhang committed
61

62
def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
    """Assemble keyword arguments for ``pl.Trainer`` from a D2Go config.

    Training length is bounded by ``cfg.SOLVER.MAX_ITER`` steps with no epoch
    limit. ``val_check_interval`` is ``cfg.TEST.EVAL_PERIOD`` when positive,
    otherwise ``cfg.SOLVER.MAX_ITER``. Gradient-clipping settings and the
    optional ``LIGHTNING_TRAINER`` overrides are added only when present/enabled.

    Args:
        cfg: The D2Go config node.

    Returns:
        A dict intended to be splatted into ``pl.Trainer(**params)``.

    Raises:
        ValueError: If norm-based gradient clipping is enabled with a
            ``NORM_TYPE`` other than 2.0.
    """
    solver = cfg.SOLVER
    use_cpu = cfg.MODEL.DEVICE.lower() == "cpu"
    strategy = _get_strategy(cfg)
    accelerator = _get_accelerator(use_cpu)

    eval_period = cfg.TEST.EVAL_PERIOD
    val_check_interval = eval_period if eval_period > 0 else solver.MAX_ITER

    num_nodes = comm.get_num_nodes()
    devices = comm.get_local_size()
    callbacks = _get_trainer_callbacks(cfg)
    tb_logger = TensorBoardLogger(save_dir=cfg.OUTPUT_DIR)
    if solver.AMP.ENABLED:
        precision = parse_precision_from_string(solver.AMP.PRECISION, lightning=True)
    else:
        precision = 32

    params: Dict[str, Any] = dict(
        max_epochs=-1,  # no epoch limit; max_steps bounds training instead
        max_steps=solver.MAX_ITER,
        val_check_interval=val_check_interval,
        num_nodes=num_nodes,
        devices=devices,
        strategy=strategy,
        accelerator=accelerator,
        callbacks=callbacks,
        logger=tb_logger,
        num_sanity_val_steps=0,
        # NOTE(review): presumably disabled because D2Go data loaders bring their
        # own distributed samplers — confirm.
        replace_sampler_ddp=False,
        precision=precision,
    )

    clip_cfg = solver.CLIP_GRADIENTS
    if clip_cfg.ENABLED:
        uses_norm = clip_cfg.CLIP_TYPE.lower() == "norm"
        if uses_norm and clip_cfg.NORM_TYPE != 2.0:
            raise ValueError(
                "D2Go Lightning backend supports only L2-norm for norm-based gradient clipping!"
            )
        params["gradient_clip_val"] = clip_cfg.CLIP_VALUE
        params["gradient_clip_algorithm"] = clip_cfg.CLIP_TYPE

    # Additional trainer parameters may be supplied under a `LIGHTNING_TRAINER` node.
    # Note that:
    #   - `LIGHTNING_TRAINER` is not part of the "base" config; users need to add it
    #     to their default config via `_DEFAULTS_` or `get_default_cfg`.
    #   - this is a temporary solution pending a future refactor of the config system.
    if hasattr(cfg, "LIGHTNING_TRAINER"):
        extra = cfg.LIGHTNING_TRAINER
        overrides = {
            "reload_dataloaders_every_n_epochs": extra.RELOAD_DATALOADERS_EVERY_N_EPOCHS,
            "sync_batchnorm": extra.SYNC_BATCHNORM,
            "benchmark": extra.BENCHMARK,
        }
        params.update(overrides)

    return params
112

Yanghan Wang's avatar
Yanghan Wang committed
113

facebook-github-bot's avatar
facebook-github-bot committed
114
115
def main(
    cfg: CfgNode,
    output_dir: str,
    runner_class: Union[str, Type[DefaultTask]],
    eval_only: bool = False,
) -> Union[TrainNetOutput, TestNetOutput]:
    """Main function for launching training (or evaluation) with the Lightning trainer.

    Args:
        cfg: D2Go config node.
        output_dir: Output directory, forwarded to ``setup_after_launch``.
        runner_class: Task class, or its name as a string; passed to
            ``setup_after_launch``, which returns the task class to instantiate.
        eval_only: True to run evaluation only (no training).

    Returns:
        ``TestNetOutput`` when ``eval_only`` is True, otherwise ``TrainNetOutput``.
        Both report the TensorBoard log directory and ``task.eval_res``; the
        training result additionally carries the model configs from ``_do_train``.
    """
    task_cls: Type[DefaultTask] = setup_after_launch(cfg, output_dir, runner_class)

    task = task_cls.from_config(cfg, eval_only)
    trainer_params = get_trainer_params(cfg)

    # The trainer's ModelCheckpoint (dirpath=cfg.OUTPUT_DIR, save_last=True) names
    # its "last" file CHECKPOINT_NAME_LAST + FILE_EXTENSION; build the path from the
    # same constants instead of hard-coding "last.ckpt".
    last_ckpt_name = (
        f"{ModelCheckpoint.CHECKPOINT_NAME_LAST}{ModelCheckpoint.FILE_EXTENSION}"
    )
    last_checkpoint = os.path.join(cfg.OUTPUT_DIR, last_ckpt_name)
    if PathManager.exists(last_checkpoint):
        # resume training from checkpoint
        trainer_params["resume_from_checkpoint"] = last_checkpoint
        logger.info("Resuming training from checkpoint: %s.", last_checkpoint)

    trainer = pl.Trainer(**trainer_params)

    if eval_only:
        _do_test(trainer, task)
        # log_dir is read after the run, as before, so the logger's version is
        # the one the run actually used.
        return TestNetOutput(
            tensorboard_log_dir=trainer_params["logger"].log_dir,
            accuracy=task.eval_res,
            metrics=task.eval_res,
        )

    model_configs = _do_train(cfg, trainer, task)
    return TrainNetOutput(
        tensorboard_log_dir=trainer_params["logger"].log_dir,
        accuracy=task.eval_res,
        metrics=task.eval_res,
        model_configs=model_configs,
    )
facebook-github-bot's avatar
facebook-github-bot committed
155
156
157
158


def argument_parser():
    """Build the command-line parser for Lightning-based D2Go runs.

    Starts from D2Go's distributed base parser (output dir not required),
    changes the default ``runner`` to the Lightning ``GeneralizedRCNNTask``
    and adds an ``--eval-only`` flag.
    """
    parser = basic_argument_parser(distributed=True, requires_output_dir=False)
    parser.add_argument(
        "--eval-only", action="store_true", help="perform evaluation only"
    )
    # Replace the base parser's default runner with the Lightning task.
    parser.set_defaults(runner="d2go.runner.lightning_task.GeneralizedRCNNTask")
    return parser


if __name__ == "__main__":
    args = argument_parser().parse_args()
    cfg, output_dir, runner_name = prepare_for_launch(args)

    ret = main(
        cfg,
        output_dir,
        runner_name,
        eval_only=args.eval_only,
    )
    # torch.distributed.get_rank() raises when no default process group is
    # initialized. Use D2Go's comm helper instead, which is expected to report
    # rank 0 in that case (detectron2-style helper — confirm), so only the main
    # process prints the result.
    if comm.get_rank() == 0:
        print(ret)