#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Detection Training Script.
"""

import logging
import sys
from typing import List, Type, Union

import detectron2.utils.comm as comm
from d2go.config import CfgNode
from d2go.distributed import launch
from d2go.runner import BaseRunner
from d2go.setup import (
    basic_argument_parser,
    build_basic_cli_args,
    post_mortem_if_fail_for_main,
    prepare_for_launch,
    setup_after_launch,
    setup_before_launch,
    setup_root_logger,
)
from d2go.trainer.api import TestNetOutput, TrainNetOutput
from d2go.trainer.fsdp import is_fsdp_enabled
from d2go.utils.misc import (
    dump_trained_model_configs,
    print_metrics_table,
    save_binary_outputs,
)
from detectron2.engine.defaults import create_ddp_model
from torch.distributed.elastic.multiprocessing.errors import (
    _NOT_AVAILABLE,
    ChildFailedError,
    get_error_handler,
)

logger = logging.getLogger("d2go.tools.train_net")


# Make sure logging is set up centrally even for e.g. dataloading workers which
# have entry points outside of D2Go.
setup_root_logger()

def main(
47
48
49
50
51
    cfg: CfgNode,
    output_dir: str,
    runner_class: Union[str, Type[BaseRunner]],
    eval_only: bool = False,
    resume: bool = True,  # NOTE: always enable resume when running on cluster
52
) -> Union[TrainNetOutput, TestNetOutput]:
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
    logger.info("Starting main")
    error_handler = get_error_handler()
    logger.debug(f">>>>>>> Error handler is: {type(error_handler)=}, {error_handler=}")
    error_handler.initialize()
    logger.debug("Error handler has been initialized")

    try:  # Main error handler starts here...
        logger.debug(f"Entered main for d2go, {runner_class=}")
        runner = setup_after_launch(cfg, output_dir, runner_class)

        model = runner.build_model(cfg)
        logger.info("Model:\n{}".format(model))

        if eval_only:
            checkpointer = runner.build_checkpointer(cfg, model, save_dir=output_dir)
            # checkpointer.resume_or_load() will skip all additional checkpointable
            # which may not be desired like ema states
            if resume and checkpointer.has_checkpoint():
                checkpoint = checkpointer.resume_or_load(
                    cfg.MODEL.WEIGHTS, resume=resume
                )
            else:
                checkpoint = checkpointer.load(cfg.MODEL.WEIGHTS)
            train_iter = checkpoint.get("iteration", None)
            model.eval()
            metrics = runner.do_test(cfg, model, train_iter=train_iter)
            print_metrics_table(metrics)
            return TestNetOutput(
                accuracy=metrics,
                metrics=metrics,
            )

        # Use DDP if FSDP is not enabled
        # TODO (T142223289): rewrite ddp wrapping as modeling hook
        if not is_fsdp_enabled(cfg):
            model = create_ddp_model(
                model,
                fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
                device_ids=None
                if cfg.MODEL.DEVICE == "cpu"
                else [comm.get_local_rank()],
                broadcast_buffers=False,
                find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
            )

        logger.info("Starting train..")
        trained_cfgs = runner.do_train(cfg, model, resume=resume)

        final_eval = cfg.TEST.FINAL_EVAL
        if final_eval:
            # run evaluation after training in the same processes
            metrics = runner.do_test(cfg, model)
            print_metrics_table(metrics)
facebook-github-bot's avatar
facebook-github-bot committed
106
        else:
107
108
109
110
111
112
            metrics = {}

        # dump config files for trained models
        trained_model_configs = dump_trained_model_configs(cfg.OUTPUT_DIR, trained_cfgs)
        return TrainNetOutput(
            # for e2e_workflow
113
            accuracy=metrics,
114
115
            # for unit_workflow
            model_configs=trained_model_configs,
116
117
            metrics=metrics,
        )
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
    except ChildFailedError as e:
        logger.info(f"Got a ChildFailedError: {e=}")
        rank, failure = e.get_first_failure()
        if failure.error_file != _NOT_AVAILABLE:
            error_handler.dump_error_file(failure.error_file, failure.exitcode)
        else:
            logger.info(
                (
                    f"local_rank {rank} FAILED with no error file."
                    f" Decorate your entrypoint fn with @record for traceback info."
                    f" See: https://pytorch.org/docs/stable/elastic/errors.html"
                )
            )
            raise
    except Exception as e:
        logger.info(f"Caught a generic exception: {e=}")
        error_handler.record_exception(e)
        raise
facebook-github-bot's avatar
facebook-github-bot committed
136
137
138


def run_with_cmdline_args(args):
139
    cfg, output_dir, runner_name = prepare_for_launch(args)
Tsahi Glik's avatar
Tsahi Glik committed
140
    shared_context = setup_before_launch(cfg, output_dir, runner_name)
141

142
    main_func = main if args.disable_post_mortem else post_mortem_if_fail_for_main(main)
143
    outputs = launch(
144
        main_func,
facebook-github-bot's avatar
facebook-github-bot committed
145
146
147
148
149
        num_processes_per_machine=args.num_processes,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        backend=args.dist_backend,
Tsahi Glik's avatar
Tsahi Glik committed
150
        shared_context=shared_context,
151
152
153
154
155
        args=(cfg, output_dir, runner_name),
        kwargs={
            "eval_only": args.eval_only,
            "resume": args.resume,
        },
facebook-github-bot's avatar
facebook-github-bot committed
156
157
    )

158
159
160
    # Only save results from global rank 0 for consistency.
    if args.save_return_file is not None and args.machine_rank == 0:
        save_binary_outputs(args.save_return_file, outputs[0])
161

162

Tsahi Glik's avatar
Tsahi Glik committed
163
def cli(args=None):
164
    logger.info(f"Inside CLI, {args=}")
facebook-github-bot's avatar
facebook-github-bot committed
165
166
167
168
169
170
171
172
173
    parser = basic_argument_parser(requires_output_dir=False)
    parser.add_argument(
        "--eval-only", action="store_true", help="perform evaluation only"
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="whether to attempt to resume from the checkpoint directory",
    )
Tsahi Glik's avatar
Tsahi Glik committed
174
    args = sys.argv[1:] if args is None else args
175
    run_with_cmdline_args(parser.parse_args(args))
facebook-github-bot's avatar
facebook-github-bot committed
176

177

178
179
180
def build_cli_args(
    eval_only: bool = False,
    resume: bool = False,
181
    **kwargs,
182
) -> List[str]:
183
184
185
186
    """Returns parameters in the form of CLI arguments for train_net binary.

    For the list of non-train_net-specific parameters, see build_basic_cli_args."""
    args = build_basic_cli_args(**kwargs)
187
188
189
190
191
192
193
    if eval_only:
        args += ["--eval-only"]
    if resume:
        args += ["--resume"]
    return args


facebook-github-bot's avatar
facebook-github-bot committed
194
if __name__ == "__main__":
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
    logger.info("Starting CLI application")
    try:
        cli()
    finally:
        logging.info("Entering final reply file generation step")
        import glob
        import os
        import shutil

        torchx_reply_files = glob.glob("/tmp/torchx_*/**/*.json", recursive=True)
        logger.info(
            f"Found the following reply files on this host: {torchx_reply_files}"
        )
        first_reply_file = None
        first_reply_file_st = float("Inf")
        for f in torchx_reply_files:
            if (mtime := os.stat(f).st_mtime) < first_reply_file_st:
                first_reply_file = f
                first_reply_file_st = mtime
        if first_reply_file and os.environ.get("MAST_HPC_TASK_FAILURE_REPLY_FILE"):
            logger.info(
                f'Copying {first_reply_file=} to {os.environ["MAST_HPC_TASK_FAILURE_REPLY_FILE"]}'
            )
            shutil.copyfile(
                first_reply_file, os.environ["MAST_HPC_TASK_FAILURE_REPLY_FILE"]
            )