# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Build segmentation models."""
from typing import Any, Mapping, Union, Optional, Dict

# Import libraries
import tensorflow as tf

layers = tf.keras.layers


@tf.keras.utils.register_keras_serializable(package='Vision')
class SegmentationModel(tf.keras.Model):
  """A Segmentation class model.

  Input images are passed through backbone first. Decoder network is then
  applied, and finally, segmentation head is applied on the output of the
  decoder network. Layers such as ASPP should be part of decoder. Any feature
  fusion is done as part of the segmentation head (i.e. deeplabv3+ feature
  fusion is not part of the decoder, instead it is part of the segmentation
  head). This way, different feature fusion techniques can be combined with
  different backbones, and decoders.

  The decoder and the mask scoring head are optional. When the decoder is
  `None`, the backbone features are handed to the head in its place. When the
  mask scoring head is `None`, the model outputs only the logits.
  """

  def __init__(self, backbone: tf.keras.Model,
               decoder: Optional[tf.keras.Model],
               head: tf.keras.layers.Layer,
               mask_scoring_head: Optional[tf.keras.layers.Layer] = None,
               **kwargs):
    """Segmentation initialization function.

    Args:
      backbone: a backbone network.
      decoder: a decoder network. E.g. FPN. May be `None`, in which case the
        backbone features are used as the decoder features.
      head: segmentation head.
      mask_scoring_head: mask scoring head. May be `None` to disable mask
        scoring.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    # Constructor arguments are recorded as-is so that `get_config` /
    # `from_config` can rebuild the model with the same sub-networks.
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'head': head,
        'mask_scoring_head': mask_scoring_head,
    }
    self.backbone = backbone
    self.decoder = decoder
    self.head = head
    self.mask_scoring_head = mask_scoring_head

  def call(self, inputs: tf.Tensor, training: Optional[bool] = None
           ) -> Dict[str, tf.Tensor]:
    """Runs the backbone, optional decoder, head and optional mask scoring.

    Args:
      inputs: the input passed to the backbone.
      training: accepted as part of the Keras `call` signature. This method
        does not forward it explicitly to the sub-networks.

    Returns:
      A dictionary with the key 'logits' holding the output of the
      segmentation head and, only when a mask scoring head is configured, the
      key 'mask_scores' holding the output of that head applied to the logits.
    """
    backbone_features = self.backbone(inputs)

    # Explicit `is not None` checks (rather than truthiness) keep this method
    # consistent with `checkpoint_items` and avoid skipping a sub-network
    # object that happens to evaluate as falsy.
    if self.decoder is not None:
      decoder_features = self.decoder(backbone_features)
    else:
      decoder_features = backbone_features

    # The head receives both feature sets so it can perform feature fusion.
    logits = self.head((backbone_features, decoder_features))
    outputs = {'logits': logits}
    if self.mask_scoring_head is not None:
      outputs['mask_scores'] = self.mask_scoring_head(logits)
    return outputs

  @property
  def checkpoint_items(
      self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(backbone=self.backbone, head=self.head)
    # Optional sub-networks are listed only when they are actually present.
    if self.decoder is not None:
      items.update(decoder=self.decoder)
    if self.mask_scoring_head is not None:
      items.update(mask_scoring_head=self.mask_scoring_head)
    return items

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments recorded at initialization.

    A shallow copy is returned so that callers adding, removing or replacing
    keys in the result cannot alter the configuration stored on the model.
    """
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds a model by passing `config` entries as constructor keywords.

    Args:
      config: a mapping of constructor keyword arguments, such as the one
        returned by `get_config`.
      custom_objects: accepted as part of the method signature; not used by
        this implementation.

    Returns:
      A new instance of `cls`.
    """
    return cls(**config)