maskrcnn_determinism_test.py 4.85 KB
Newer Older
Reed Wanderman-Milne's avatar
Reed Wanderman-Milne committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test that Mask RCNN is deterministic when TF determinism is enabled."""

# pylint: disable=unused-import
from absl.testing import parameterized
import orbit
import tensorflow as tf

from official.core import exp_factory
from official.modeling import optimization
from official.vision.tasks import maskrcnn


class MaskRcnnTaskTest(parameterized.TestCase, tf.test.TestCase):
  """Checks that Mask R-CNN-family tasks train and evaluate deterministically.

  Each test builds the same experiment twice from the same random seed, runs
  one training step and one validation step, and requires the training logs,
  selected detection outputs and final weights of the two runs to be
  identical.
  """

  def setUp(self):
    super().setUp()
    # The `__main__` guard at the bottom of the file enables op determinism,
    # but that code path is skipped when the test is collected by a different
    # runner (e.g. pytest). Enabling it here makes every test case run with
    # deterministic kernels regardless of how the module is launched; enabling
    # it again when it is already on has no additional effect.
    tf.config.experimental.enable_op_determinism()

  def _edit_config_for_testing(self, config):
    """Shrinks `config` in place so a single train/eval step runs quickly.

    Args:
      config: An experiment config as returned by
        `exp_factory.get_exp_config`. It is mutated in place.
    """
    # Small batch, small backbone, few proposals/filters and tiny images keep
    # the two back-to-back runs cheap.
    config.trainer.steps_per_loop = 1
    config.task.train_data.global_batch_size = 2
    config.task.model.backbone.resnet.model_id = 18
    config.task.model.decoder.fpn.num_filters = 32
    config.task.model.detection_generator.pre_nms_top_k = 500
    config.task.model.detection_head.fc_dims = 128
    if config.task.model.include_mask:
      config.task.model.mask_sampler.num_sampled_masks = 10
      config.task.model.mask_head.num_convs = 1
    config.task.model.roi_generator.num_proposals = 100
    config.task.model.roi_generator.pre_nms_top_k = 150
    config.task.model.roi_generator.test_pre_nms_top_k = 150
    config.task.model.roi_generator.test_num_proposals = 100
    config.task.model.rpn_head.num_filters = 32
    config.task.model.roi_sampler.num_sampled_rois = 200
    config.task.model.input_size = [128, 128, 3]
    config.trainer.train_steps = 2
    config.task.train_data.shuffle_buffer_size = 2
    # NOTE(review): these are relative paths; the test presumably runs from a
    # directory containing a `coco/` folder with these shards -- confirm.
    config.task.train_data.input_path = "coco/train-00000-of-00256.tfrecord"
    config.task.validation_data.global_batch_size = 2
    config.task.validation_data.input_path = "coco/val-00000-of-00032.tfrecord"

  def _build_and_run_model(self, config):
    """Builds a fresh task/model from `config`; runs one train and eval step.

    Args:
      config: The (already shrunk) experiment config.

    Returns:
      A tuple `(logs, validation_logs, weights)`:
        logs: the training-step logs, with each training metric's result
          added under the metric's name.
        validation_logs: the validation-step logs, with each validation
          metric's result added under the metric's name.
        weights: `model.weights` after both steps have run.
    """
    task = maskrcnn.MaskRCNNTask(config.task)
    model = task.build_model()
    train_metrics = task.build_metrics(training=True)
    validation_metrics = task.build_metrics(training=False)

    strategy = tf.distribute.get_strategy()

    train_dataset = orbit.utils.make_distributed_dataset(
        strategy, task.build_inputs, config.task.train_data)
    train_iterator = iter(train_dataset)
    validation_dataset = orbit.utils.make_distributed_dataset(
        strategy, task.build_inputs, config.task.validation_data)
    validation_iterator = iter(validation_dataset)
    opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())

    # Run training
    logs = task.train_step(next(train_iterator), model, optimizer,
                           metrics=train_metrics)
    for metric in train_metrics:
      logs[metric.name] = metric.result()

    # Run validation
    validation_logs = task.validation_step(next(validation_iterator), model,
                                           metrics=validation_metrics)
    for metric in validation_metrics:
      validation_logs[metric.name] = metric.result()

    return logs, validation_logs, model.weights

  @parameterized.parameters(
      "fasterrcnn_resnetfpn_coco",
      "maskrcnn_resnetfpn_coco",
      "maskrcnn_spinenet_coco",
      "cascadercnn_spinenet_coco",
  )
  def test_maskrcnn_task_train(self, test_config):
    """Two identically seeded train+eval runs must produce identical results."""
    config = exp_factory.get_exp_config(test_config)
    self._edit_config_for_testing(config)

    tf.keras.utils.set_random_seed(1)
    logs1, validation_logs1, weights1 = self._build_and_run_model(config)
    tf.keras.utils.set_random_seed(1)
    logs2, validation_logs2, weights2 = self._build_and_run_model(config)

    # The headline losses must be reported at all...
    self.assertIn("loss", logs1)
    self.assertIn("total_loss", logs1)
    # ...and every training log entry (the reported loss plus each training
    # metric's result) must match between the two runs.
    self.assertCountEqual(logs1.keys(), logs2.keys())
    for key in logs1:
      self.assertAllEqual(logs1[key], logs2[key], msg=f"train log {key!r}")

    # Index 1 of the "coco_metric" entry appears to hold the model's
    # detection outputs (it carries the detection_* keys compared below).
    detections1 = validation_logs1["coco_metric"][1]
    detections2 = validation_logs2["coco_metric"][1]
    for key in ("detection_boxes", "detection_scores", "detection_classes"):
      self.assertAllEqual(detections1[key], detections2[key],
                          msg=f"validation output {key!r}")

    # zip() silently stops at the shorter list, so check the counts first.
    self.assertLen(weights2, len(weights1))
    for weight1, weight2 in zip(weights1, weights2):
      self.assertAllEqual(weight1, weight2)


if __name__ == "__main__":
  # Turn on deterministic op kernels process-wide before the test runner
  # starts, so every test case in this module executes under TF op
  # determinism (the property this module is testing).
  tf.config.experimental.enable_op_determinism()
  tf.test.main()