# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Yolo models."""

from typing import Mapping, Union
import tensorflow as tf
from official.projects.yolo.modeling.layers import nn_blocks


class Yolo(tf.keras.Model):
  """The YOLO model class.

  Runs the input through a backbone, a decoder and a head. When not training,
  the raw head output is additionally passed through the detection generator.
  """

  def __init__(self,
               backbone,
               decoder,
               head,
               detection_generator,
               **kwargs):
    """Detection initialization function.

    Args:
      backbone: `tf.keras.Model` a backbone network.
      decoder: `tf.keras.Model` a decoder network.
      head: the prediction head, called on the decoder output.
      detection_generator: the detection generator, called on the head output
        when the model is not training.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)

    # Kept so that `get_config` / `from_config` can rebuild the model from the
    # same component objects.
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'head': head,
        'detection_generator': detection_generator
    }

    # model components
    self._backbone = backbone
    self._decoder = decoder
    self._head = head
    self._detection_generator = detection_generator
    # Set to True by `fuse()` so that the fusing pass runs at most once.
    self._fused = False

  def call(self, inputs, training=False):
    """Runs the model.

    Args:
      inputs: the input passed to the backbone.
      training: whether the model is called in training mode.

    Returns:
      When `training` is true, a dict with the single key 'raw_output' holding
      the head output. Otherwise, the detection generator output updated with
      the same 'raw_output' entry.
    """
    maps = self.backbone(inputs)
    decoded_maps = self.decoder(maps)
    raw_predictions = self.head(decoded_maps)
    if training:
      return {'raw_output': raw_predictions}
    else:
      # Post-processing.
      predictions = self.detection_generator(raw_predictions)
      predictions.update({'raw_output': raw_predictions})
      return predictions

  @property
  def backbone(self):
    """The backbone network passed to the constructor."""
    return self._backbone

  @property
  def decoder(self):
    """The decoder network passed to the constructor."""
    return self._decoder

  @property
  def head(self):
    """The head passed to the constructor."""
    return self._head

  @property
  def detection_generator(self):
    """The detection generator passed to the constructor."""
    return self._detection_generator

  def get_config(self):
    """Returns the constructor arguments as a dict.

    A shallow copy is returned so that callers mutating the result do not
    alter the configuration stored on the model.
    """
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config):
    """Builds a model from a dict as returned by `get_config`."""
    return cls(**config)

  @property
  def checkpoint_items(
      self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(backbone=self.backbone, head=self.head)
    if self.decoder is not None:
      items.update(decoder=self.decoder)
    return items

  def fuse(self):
    """Fuses all Convolution and Batchnorm layers to get better latency.

    The pass runs only once; later calls return without doing anything.
    """
    if self._fused:
      return
    print('Fusing Conv Batch Norm Layers.')
    self._fused = True
    for layer in self.submodules:
      if isinstance(layer, nn_blocks.ConvBN):
        layer.fuse()
    self.summary()