"examples/pytorch/dgmg/util.py" did not exist on "6105e441426f97f31d96c54d6f35830028c2b3f6"
train_net.py 6.51 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Detection Training Script.
"""

import logging
import sys
from typing import Dict, List, Type, Union

import detectron2.utils.comm as comm
from d2go.config import CfgNode
from d2go.distributed import distributed_worker, launch
from d2go.runner import BaseRunner
from d2go.runner.config_defaults import preprocess_cfg
from d2go.setup import (
    basic_argument_parser,
    build_basic_cli_args,
    post_mortem_if_fail_for_main,
    prepare_for_launch,
    setup_after_launch,
    setup_before_launch,
    setup_root_logger,
)
from d2go.trainer.api import TestNetOutput, TrainNetOutput
from d2go.trainer.fsdp import is_fsdp_enabled
from d2go.utils.mast import gather_mast_errors, mast_error_handler
from d2go.utils.misc import (
    dump_trained_model_configs,
    print_metrics_table,
    save_binary_outputs,
)
from detectron2.engine.defaults import create_ddp_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

logger = logging.getLogger("d2go.tools.train_net")
# Make sure logging is set up centrally even for e.g. dataloading workers which
# have entry points outside of D2Go.
setup_root_logger()


TrainOrTestNetOutput = Union[TrainNetOutput, TestNetOutput]


def main(
    cfg: CfgNode,
    output_dir: str,
    runner_class: Union[str, Type[BaseRunner]],
    eval_only: bool = False,
    resume: bool = True,  # NOTE: always enable resume when running on cluster
) -> TrainOrTestNetOutput:
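    """Train a model (or evaluate one, when `eval_only` is True) per `cfg`.

    Builds the model through the runner, wraps it in DDP unless it is already
    FSDP-wrapped, trains it, optionally runs a final evaluation
    (cfg.TEST.FINAL_EVAL), and dumps the configs of the trained models.
    Returns a TrainNetOutput, or a TestNetOutput in the eval-only case.
    """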
    logger.debug(f"Entered main for d2go, {runner_class=}")
    runner = setup_after_launch(cfg, output_dir, runner_class)

    model = runner.build_model(cfg)
    logger.info("Model:\n{}".format(model))

    if eval_only:
        checkpointer = runner.build_checkpointer(cfg, model, save_dir=output_dir)
        # checkpointer.resume_or_load() will skip all additional
        # checkpointables, which may not be desired (e.g. EMA states)
        if resume and checkpointer.has_checkpoint():
            checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
        else:
            checkpoint = checkpointer.load(cfg.MODEL.WEIGHTS)
        train_iter = checkpoint.get("iteration", None)
        model.eval()
        metrics = runner.do_test(cfg, model, train_iter=train_iter)
        print_metrics_table(metrics)
        return TestNetOutput(
            accuracy=metrics,
            metrics=metrics,
        )

    # Use DDP if FSDP is not enabled
    # TODO (T142223289): rewrite ddp wrapping as modeling hook
    if not isinstance(model, FSDP):
        model = create_ddp_model(
            model,
            fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
            device_ids=None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
            broadcast_buffers=False,
            find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
            gradient_as_bucket_view=cfg.MODEL.DDP_GRADIENT_AS_BUCKET_VIEW,
        )

    logger.info("Starting train..")
    trained_cfgs = runner.do_train(cfg, model, resume=resume)

    final_eval = cfg.TEST.FINAL_EVAL
    if final_eval:
        # run evaluation after training in the same processes
        metrics = runner.do_test(cfg, model)
        print_metrics_table(metrics)
    else:
        metrics = {}

    # dump config files for trained models
    trained_model_configs = dump_trained_model_configs(cfg.OUTPUT_DIR, trained_cfgs)
    return TrainNetOutput(
        # for e2e_workflow
        accuracy=metrics,
        # for unit_workflow
        model_configs=trained_model_configs,
        metrics=metrics,
    )


def wrapped_main(*args, **kwargs) -> TrainOrTestNetOutput:
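    """Run `main` wrapped with MAST error handling (see d2go.utils.mast)."""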
    return mast_error_handler(main)(*args, **kwargs)


def run_with_cmdline_args(args):
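    """Prepare the config from parsed CLI args and launch `main`.

    Runs either in-process as a single distributed worker (when
    `args.run_as_worker` is set) or by launching `args.num_processes` workers
    per machine; on machine rank 0 the local master's result is optionally
    written to `args.save_return_file`.
    """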
    cfg, output_dir, runner_name = prepare_for_launch(args)
    cfg = preprocess_cfg(cfg)
    shared_context = setup_before_launch(cfg, output_dir, runner_name)

    main_func = (
        wrapped_main
        if args.disable_post_mortem
        else post_mortem_if_fail_for_main(wrapped_main)
    )

    if args.run_as_worker:
        logger.info("Running as worker")
        result: TrainOrTestNetOutput = distributed_worker(
            main_func,
            args=(cfg, output_dir, runner_name),
            kwargs={
                "eval_only": args.eval_only,
                "resume": args.resume,
            },
            backend=args.dist_backend,
            init_method=None,  # init_method is env by default
            dist_params=None,
            return_save_file=None,
            shared_context=shared_context,
        )
    else:
        outputs: Dict[int, TrainOrTestNetOutput] = launch(
            main_func,
            num_processes_per_machine=args.num_processes,
            num_machines=args.num_machines,
            machine_rank=args.machine_rank,
            dist_url=args.dist_url,
            backend=args.dist_backend,
            shared_context=shared_context,
            args=(cfg, output_dir, runner_name),
            kwargs={
                "eval_only": args.eval_only,
                "resume": args.resume,
            },
        )
        # The keys of `outputs` are the global ranks of all workers on this
        # node, so the local master's result sits at global rank
        # machine_rank * num_processes (e.g. rank 8 on machine 1 with 8
        # processes per machine).
        result: TrainOrTestNetOutput = outputs[args.machine_rank * args.num_processes]

    # Only save result from global rank 0 for consistency.
    if args.save_return_file is not None and args.machine_rank == 0:
        logger.info(f"Operator result: {result}")
        logger.info(f"Writing result to {args.save_return_file}.")
        save_binary_outputs(args.save_return_file, result)


def cli(args=None):
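    """Parse train_net CLI flags (defaulting to sys.argv[1:]) and run."""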
    logger.info(f"Inside CLI, {args=}")
    parser = basic_argument_parser(requires_output_dir=False)
    parser.add_argument(
        "--eval-only", action="store_true", help="perform evaluation only"
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="whether to attempt to resume from the checkpoint directory",
    )
    args = sys.argv[1:] if args is None else args
    run_with_cmdline_args(parser.parse_args(args))


def build_cli_args(
    eval_only: bool = False,
    resume: bool = False,
    **kwargs,
) -> List[str]:
    """Returns parameters in the form of CLI arguments for train_net binary.

    For the list of non-train_net-specific parameters, see build_basic_cli_args."""
    args = build_basic_cli_args(**kwargs)
    if eval_only:
        args += ["--eval-only"]
    if resume:
        args += ["--resume"]
    return args
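

# Example (illustrative): the returned list can be fed back into `cli`, e.g.
#     cli(build_cli_args(eval_only=True, resume=True, **basic_kwargs))
# where `basic_kwargs` holds whatever keyword arguments
# `build_basic_cli_args` expects (see d2go.setup for the exact names).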


if __name__ == "__main__":
    gather_mast_errors(cli())