"torchvision/transforms/_functional_pil.py" did not exist on "96f6e0a117d5c56f7e0237851dbb96144ebb110b"
train_net.py 4.55 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Detection Training Script.
"""

import logging
9
import sys
10
from typing import List, Type, Union
facebook-github-bot's avatar
facebook-github-bot committed
11
12

import detectron2.utils.comm as comm
13
from d2go.config import CfgNode
facebook-github-bot's avatar
facebook-github-bot committed
14
from d2go.distributed import launch
15
from d2go.runner import BaseRunner
facebook-github-bot's avatar
facebook-github-bot committed
16
17
from d2go.setup import (
    basic_argument_parser,
18
    build_basic_cli_args,
facebook-github-bot's avatar
facebook-github-bot committed
19
20
21
    post_mortem_if_fail_for_main,
    prepare_for_launch,
    setup_after_launch,
Tsahi Glik's avatar
Tsahi Glik committed
22
    setup_before_launch,
23
    setup_root_logger,
facebook-github-bot's avatar
facebook-github-bot committed
24
)
25
from d2go.trainer.api import TrainNetOutput
26
27
28
29
30
from d2go.utils.misc import (
    dump_trained_model_configs,
    print_metrics_table,
    save_binary_outputs,
)
31
from detectron2.engine.defaults import create_ddp_model
facebook-github-bot's avatar
facebook-github-bot committed
32
33
34
35


# Module-level logger, registered under the name "d2go.tools.train_net".
logger = logging.getLogger("d2go.tools.train_net")

36

facebook-github-bot's avatar
facebook-github-bot committed
37
def main(
    cfg: CfgNode,
    output_dir: str,
    runner_class: Union[str, Type[BaseRunner]],
    eval_only: bool = False,
    resume: bool = True,  # NOTE: always enable resume when running on cluster
) -> TrainNetOutput:
    """Build a model through the runner, then either evaluate it or train it.

    Args:
        cfg: Config node handed to the runner for building, training and testing.
        output_dir: Directory passed to ``setup_after_launch`` and, in eval-only
            mode, used as the checkpointer's ``save_dir``.
        runner_class: Runner class (or its name) passed to ``setup_after_launch``.
        eval_only: When true, load weights, run ``do_test`` and return without
            training.
        resume: Forwarded to ``do_train``; in eval-only mode it also decides
            whether an existing checkpoint is preferred over ``cfg.MODEL.WEIGHTS``.

    Returns:
        A ``TrainNetOutput`` whose ``accuracy`` and ``metrics`` fields both hold
        the evaluation metrics. ``model_configs`` is empty in eval-only mode and
        otherwise holds the result of ``dump_trained_model_configs``.
    """
    runner = setup_after_launch(cfg, output_dir, runner_class)

    model = runner.build_model(cfg)
    logger.info("Model:\n{}".format(model))

    if eval_only:
        checkpointer = runner.build_checkpointer(cfg, model, save_dir=output_dir)
        # NOTE (carried over from the original code): checkpointer.resume_or_load()
        # will skip all additional checkpointables (e.g. EMA states), which may
        # not be desired.
        # Short-circuit: has_checkpoint() is only consulted when resume is set.
        load_from_existing = resume and checkpointer.has_checkpoint()
        if load_from_existing:
            checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
        else:
            checkpoint = checkpointer.load(cfg.MODEL.WEIGHTS)

        model.eval()
        eval_metrics = runner.do_test(
            cfg, model, train_iter=checkpoint.get("iteration", None)
        )
        print_metrics_table(eval_metrics)
        return TrainNetOutput(
            accuracy=eval_metrics,
            model_configs={},
            metrics=eval_metrics,
        )

    if cfg.MODEL.DEVICE == "cpu":
        ddp_device_ids = None
    else:
        ddp_device_ids = [comm.get_local_rank()]
    model = create_ddp_model(
        model,
        fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
        device_ids=ddp_device_ids,
        broadcast_buffers=False,
        find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
    )

    trained_cfgs = runner.do_train(cfg, model, resume=resume)

    # Metrics stay empty unless a final evaluation is requested by the config.
    metrics = {}
    if cfg.TEST.FINAL_EVAL:
        # run evaluation after training in the same processes
        metrics = runner.do_test(cfg, model)
        print_metrics_table(metrics)

    # dump config files for trained models
    trained_model_configs = dump_trained_model_configs(cfg.OUTPUT_DIR, trained_cfgs)
    return TrainNetOutput(
        accuracy=metrics,  # for e2e_workflow
        model_configs=trained_model_configs,  # for unit_workflow
        metrics=metrics,
    )
facebook-github-bot's avatar
facebook-github-bot committed
94
95
96


def run_with_cmdline_args(args):
    """Launch ``main`` across processes using settings from parsed CLI args.

    Args:
        args: Parsed argument namespace. This function reads
            ``disable_post_mortem``, ``num_processes``, ``num_machines``,
            ``machine_rank``, ``dist_url``, ``dist_backend``, ``eval_only``,
            ``resume`` and ``save_return_file`` from it, and also hands the
            whole namespace to ``prepare_for_launch``.
    """
    cfg, output_dir, runner_name = prepare_for_launch(args)
    shared_context = setup_before_launch(cfg, output_dir, runner_name)

    # Wrap main with the post-mortem helper unless that was disabled.
    if args.disable_post_mortem:
        entry_point = main
    else:
        entry_point = post_mortem_if_fail_for_main(main)

    launch_options = {
        "num_processes_per_machine": args.num_processes,
        "num_machines": args.num_machines,
        "machine_rank": args.machine_rank,
        "dist_url": args.dist_url,
        "backend": args.dist_backend,
        "shared_context": shared_context,
        "args": (cfg, output_dir, runner_name),
        "kwargs": {
            "eval_only": args.eval_only,
            "resume": args.resume,
        },
    }
    outputs = launch(entry_point, **launch_options)

    # Only save results from global rank 0 for consistency.
    if args.save_return_file is None or args.machine_rank != 0:
        return
    save_binary_outputs(args.save_return_file, outputs[0])
119

120

Tsahi Glik's avatar
Tsahi Glik committed
121
def cli(args=None):
    """Parse command-line arguments and run training/evaluation with them.

    Args:
        args: Optional sequence of argument strings. When ``None``,
            ``sys.argv[1:]`` is parsed instead.
    """
    parser = basic_argument_parser(requires_output_dir=False)
    parser.add_argument(
        "--eval-only",
        action="store_true",
        help="perform evaluation only",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="whether to attempt to resume from the checkpoint directory",
    )

    if args is None:
        args = sys.argv[1:]
    parsed_args = parser.parse_args(args)
    run_with_cmdline_args(parsed_args)
facebook-github-bot's avatar
facebook-github-bot committed
133

134

135
136
137
def build_cli_args(
    eval_only: bool = False,
    resume: bool = False,
    **kwargs,
) -> List[str]:
    """Returns parameters in the form of CLI arguments for train_net binary.

    The remaining keyword arguments are forwarded to build_basic_cli_args;
    see that function for the list of non-train_net-specific parameters.
    """
    cli_args = build_basic_cli_args(**kwargs)
    # Append each train_net-specific switch whose flag value is truthy.
    for enabled, switch in ((eval_only, "--eval-only"), (resume, "--resume")):
        if enabled:
            cli_args += [switch]
    return cli_args


facebook-github-bot's avatar
facebook-github-bot committed
151
# Script entry point: set up the root logger, then run the CLI
# (cli() parses sys.argv[1:] when called without arguments).
if __name__ == "__main__":
    setup_root_logger()
    cli()