mask_rcnn_main.py 5.69 KB
Newer Older
zhenyi's avatar
zhenyi committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training script for Mask-RCNN."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Logging-related environment variables, assigned ahead of the `import tensorflow`
# statement further down in this file.
# NOTE(review): presumably TensorFlow reads these when it is first loaded, so the
# assignments need to stay above that import -- confirm before reordering.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # or any {'0', '1', '2'}
# Per-module verbosity, all set to 0 for the listed modules.
os.environ["TF_CPP_VMODULE"] = 'non_max_suppression_op=0,generate_box_proposals_op=0,executor=0'
# os.environ["TF_XLA_FLAGS"] = 'tf_xla_print_cluster_outputs=1'

from absl import app

import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution

from mask_rcnn.utils.logging_formatter import logging
from mask_rcnn.utils.distributed_utils import MPI_is_distributed

from mask_rcnn import dataloader
from mask_rcnn import distributed_executer
from mask_rcnn import mask_rcnn_model

from mask_rcnn.hyperparameters import mask_rcnn_params
from mask_rcnn.hyperparameters import params_io

from mask_rcnn.hyperparameters.cmdline_utils import define_hparams_flags

from mask_rcnn.utils.logging_formatter import log_cleaning
import dllogger

# Flag container returned by `define_hparams_flags()`; `run_executer` reads
# `FLAGS.eval_after_training` and `main` reads `FLAGS.flag_values_dict()`.
FLAGS = define_hparams_flags()


def run_executer(runtime_config, train_input_fn=None, eval_input_fn=None):
    """Builds the executer selected by the config and runs the requested mode.

    Args:
        runtime_config: Configuration object; `use_tf_distributed` picks the executer
            class and `mode` picks the action to run.
        train_input_fn: Input function forwarded to the training entry points.
        eval_input_fn: Input function forwarded to the evaluation entry points.

    Raises:
        ValueError: If `runtime_config.mode` is not `train`, `eval` or `train_and_eval`.
    """
    executer_cls = (
        distributed_executer.TFDistributedExecuter
        if runtime_config.use_tf_distributed
        else distributed_executer.EstimatorExecuter
    )
    # The executer is constructed before the mode is inspected, as in the dispatch below.
    executer = executer_cls(runtime_config, mask_rcnn_model.mask_rcnn_model_fn)

    requested_mode = runtime_config.mode

    if requested_mode == 'train':
        executer.train(
            train_input_fn=train_input_fn,
            run_eval_after_train=FLAGS.eval_after_training,
            eval_input_fn=eval_input_fn
        )
        return

    if requested_mode == 'eval':
        executer.eval(eval_input_fn=eval_input_fn)
        return

    if requested_mode == 'train_and_eval':
        executer.train_and_eval(train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)
        return

    raise ValueError('Mode must be one of `train`, `eval`, or `train_and_eval`')


def main(argv):
    """Builds the run configuration from flags, validates it and launches the executer.

    The default configuration is overridden with the command-line flag values, the
    mode-specific inputs are checked, dllogger is initialised, the input readers
    required by the selected mode are created and `run_executer` is called.

    Args:
        argv: Positional arguments handed over by `app.run`; not used.

    Raises:
        RuntimeError: If `use_tf_distributed` is set while `MPI_is_distributed()` is true,
            if `training_file_pattern` is missing for a training mode, if
            `validation_file_pattern` is missing for an evaluation mode, or if neither
            `val_json_file` nor `include_groundtruth_in_features` is provided for an
            evaluation mode.
        FileNotFoundError: If `val_json_file` is required but is not an existing file.
    """
    del argv  # Unused.

    # ============================ Configure parameters ============================ #
    RUN_CONFIG = mask_rcnn_params.default_config()

    temp_config = FLAGS.flag_values_dict()
    temp_config['learning_rate_decay_levels'] = [float(decay) for decay in temp_config['learning_rate_decay_levels']]
    # Absolute learning-rate levels are the decay factors scaled by the initial learning rate.
    temp_config['learning_rate_levels'] = [
        decay * temp_config['init_learning_rate'] for decay in temp_config['learning_rate_decay_levels']
    ]
    temp_config['learning_rate_steps'] = [int(step) for step in temp_config['learning_rate_steps']]

    RUN_CONFIG = params_io.override_hparams(RUN_CONFIG, temp_config)
    # ============================ Configure parameters ============================ #
    if RUN_CONFIG.use_tf_distributed and MPI_is_distributed():
        raise RuntimeError("Incompatible Runtime. Impossible to use `--use_tf_distributed` with MPIRun Horovod")

    runs_training = RUN_CONFIG.mode in ('train', 'train_and_eval')

    # Fixed: the guard used to test `eval_samples`, which is unrelated to the training
    # input; the training reader below consumes `training_file_pattern`.
    if runs_training and not RUN_CONFIG.training_file_pattern:
        raise RuntimeError('You must specify `training_file_pattern` for training.')

    if RUN_CONFIG.mode in ('eval', 'train_and_eval'):
        if not RUN_CONFIG.validation_file_pattern:
            raise RuntimeError('You must specify `validation_file_pattern` for evaluation.')

        if RUN_CONFIG.val_json_file == "" and not RUN_CONFIG.include_groundtruth_in_features:
            raise RuntimeError(
                'You must specify `val_json_file` or include_groundtruth_in_features=True for evaluation.')

        if not RUN_CONFIG.include_groundtruth_in_features and not os.path.isfile(RUN_CONFIG.val_json_file):
            raise FileNotFoundError("Validation JSON File not found: %s" % RUN_CONFIG.val_json_file)

    dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                                           filename=RUN_CONFIG.log_path)])

    if runs_training:
        train_input_fn = dataloader.InputReader(
            file_pattern=RUN_CONFIG.training_file_pattern,
            mode=tf.estimator.ModeKeys.TRAIN,
            num_examples=None,
            use_fake_data=RUN_CONFIG.use_fake_data,
            use_instance_mask=RUN_CONFIG.include_mask,
            seed=RUN_CONFIG.seed
        )

    else:
        train_input_fn = None

    # Fixed: the `or (...)` clause used to sit inside the tuple literal, where it
    # short-circuited to 'train_and_eval' and was never evaluated, so a `train` run with
    # evaluation after training received no eval input function. The flag is read from
    # FLAGS because that is what `run_executer` forwards to `executer.train`.
    needs_eval_input = RUN_CONFIG.mode in ('eval', 'train_and_eval') or (
        RUN_CONFIG.mode == 'train' and FLAGS.eval_after_training
    )

    if needs_eval_input:
        eval_input_fn = dataloader.InputReader(
            file_pattern=RUN_CONFIG.validation_file_pattern,
            mode=tf.estimator.ModeKeys.PREDICT,
            num_examples=RUN_CONFIG.eval_samples,
            use_fake_data=False,
            use_instance_mask=RUN_CONFIG.include_mask,
            seed=RUN_CONFIG.seed
        )

    else:
        eval_input_fn = None

    run_executer(RUN_CONFIG, train_input_fn, eval_input_fn)


# Script entry point: configure logging and TensorFlow execution settings, then
# hand control to absl, which calls `main`.
if __name__ == '__main__':
    logging.set_verbosity(logging.INFO)
    # Runs before `app.run(main)`, i.e. before any code in `main` executes.
    disable_eager_execution()
    # NOTE(review): this second verbosity call follows the INFO setting above and
    # presumably replaces it -- confirm which level is intended.
    logging.set_verbosity(logging.DEBUG)
    tf.autograph.set_verbosity(0)
    log_cleaning(hide_deprecation_warnings=True)

    app.run(main)