Commit a04d9e0e authored by Vishnu Banna

merged

parents 64f16d61 bcbce005
......@@ -26,7 +26,12 @@ from official.vision.beta.configs import backbones
@dataclasses.dataclass
class Darknet(hyperparams.Config):
"""Darknet config."""
model_id: str = "darknet53"
model_id: str = 'darknet53'
width_scale: float = 1.0
depth_scale: float = 1.0
dilate: bool = False
min_level: int = 3
max_level: int = 5
@dataclasses.dataclass
......
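For context, a minimal sketch of instantiating the extended config (the import path is an assumption; the field names and defaults come from the hunk above):

    from official.vision.beta.projects.yolo.configs import backbones  # assumed path

    backbone = backbones.Darknet(
        model_id='darknet53',  # which Darknet variant to build
        width_scale=1.0,       # scales channel counts
        depth_scale=1.0,       # scales block repetitions
        dilate=False,          # use dilated convolutions in upper levels
        min_level=3,           # lowest feature level exposed
        max_level=5)           # highest feature level exposed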
......@@ -13,7 +13,6 @@
# limitations under the License.
# Lint as: python3
"""Contains definitions of Darknet Backbone Networks.
These models are inspired by ResNet and CSPNet.
......@@ -46,16 +45,14 @@ from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
# builder required classes
class BlockConfig:
"""
This is a class to store layer config to make code more readable.
"""
"""Class to store layer config to make code more readable."""
def __init__(self, layer, stack, reps, bottleneck, filters, pool_size,
kernel_size, strides, padding, activation, route, dilation_rate,
output_name, is_output):
"""
"""Initializing method for BlockConfig.
Args:
layer: A `str` for layer name.
stack: A `str` for the type of layer ordering to use for this specific
......@@ -69,7 +66,7 @@ class BlockConfig:
padding: An `int` for the padding to apply to layers in this stack.
activation: A `str` for the activation to use for this stack.
route: An `int` for the level to route from to get the next input.
dilation_rate: An `int` for the scale used in dilated Darknet.
output_name: A `str` for the name to use for this output.
is_output: A `bool` for whether this layer is an output in the default
model.
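To make the constructor above concrete, a sketch of one block entry (parameter names from the signature; values are illustrative, not taken from the diff):

    block = BlockConfig(
        layer='DarkRes', stack='residual', reps=2, bottleneck=False,
        filters=128, pool_size=None, kernel_size=3, strides=1, padding='same',
        activation='leaky', route=-1, dilation_rate=1,
        output_name='3', is_output=True)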
......@@ -98,11 +95,11 @@ def build_block_specs(config):
class LayerBuilder:
"""
This is a class that is used for quick look up of default layers used
by darknet to connect, introduce or exit a level. Used in place of an
if condition or switch to make adding new layers easier and to reduce
redundant code.
"""Layer builder class.
Class for quick lookup of the default layers Darknet uses to connect,
introduce, or exit a level. Used in place of an if/switch chain to make
adding new layers easier and to reduce redundant code.
"""
def __init__(self):
......@@ -378,7 +375,7 @@ BACKBONES = {
@tf.keras.utils.register_keras_serializable(package='yolo')
class Darknet(tf.keras.Model):
""" The Darknet backbone architecture. """
"""The Darknet backbone architecture."""
def __init__(
self,
......@@ -596,8 +593,8 @@ class Darknet(tf.keras.Model):
filters=config.filters, downsample=True, **self._default_dict)(
inputs)
dilated_reps = config.repetitions - \
(self._default_dict['dilation_rate'] // 2) - 1
dilated_reps = config.repetitions - (
self._default_dict['dilation_rate'] // 2) - 1
for i in range(dilated_reps):
self._default_dict['name'] = f'{name}_{i}'
x = nn_blocks.DarkResidual(
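A quick worked check of the reformatted expression (values illustrative):

    repetitions = 8                                        # config.repetitions
    dilation_rate = 4                                      # self._default_dict['dilation_rate']
    dilated_reps = repetitions - (dilation_rate // 2) - 1  # 8 - 2 - 1 = 5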
......@@ -668,14 +665,13 @@ def build_darknet(
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds darknet."""
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
backbone_cfg = backbone_config.get()
model = Darknet(
model_id=backbone_cfg.model_id,
min_level=model_config.min_level,
max_level=model_config.max_level,
min_level=backbone_cfg.min_level,
max_level=backbone_cfg.max_level,
input_specs=input_specs,
dilate=backbone_cfg.dilate,
width_scale=backbone_cfg.width_scale,
......@@ -686,4 +682,4 @@ def build_darknet(
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
model.summary()
return model
return model
\ No newline at end of file
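A hedged sketch of a call site for the updated factory; `model_config` is an assumed experiment config object, and only the keyword names come from the signature above:

    input_specs = tf.keras.layers.InputSpec(shape=[None, 416, 416, 3])
    model = build_darknet(
        input_specs=input_specs,
        backbone_config=model_config.backbone,  # min/max level now read here
        norm_activation_config=model_config.norm_activation,
        l2_regularizer=tf.keras.regularizers.l2(5e-4))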
......@@ -13,7 +13,7 @@
# limitations under the License.
# Lint as: python3
"""Tests for YOLO."""
"""Tests for yolo."""
from absl.testing import parameterized
import numpy as np
......@@ -125,6 +125,5 @@ class DarknetTest(parameterized.TestCase, tf.test.TestCase):
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(network.get_config(), new_network.get_config())
if __name__ == '__main__':
tf.test.main()
tf.test.main()
\ No newline at end of file
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -22,15 +22,7 @@ from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
@tf.keras.utils.register_keras_serializable(package='yolo')
class _IdentityRoute(tf.keras.layers.Layer):
def __init__(self, **kwargs):
"""
Private class to mirror the outputs of blocks in nn_blocks for an easier
programatic generation of the feature pyramid network.
"""
super().__init__(**kwargs)
def call(self, inputs): # pylint: disable=arguments-differ
def call(self, inputs):
return None, inputs
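A sketch of why the layer exists: it mirrors the (route, output) pair returned by nn_blocks layers, so the FPN wiring code can treat both uniformly (shape illustrative):

    import tensorflow as tf

    layer = _IdentityRoute()
    route, out = layer(tf.ones([1, 13, 13, 256]))
    # route is None; out is the input tensor, unchanged.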
......@@ -111,8 +103,7 @@ class YoloFPN(tf.keras.layers.Layer):
return list(reversed(depths))
def build(self, inputs):
"""Use config dictionary to generate all important attributes for head
construction.
"""Use config dictionary to generate all important attributes for head.
Args:
inputs: a dictionary mapping each input's name to its shape as a list.
......@@ -127,7 +118,7 @@ class YoloFPN(tf.keras.layers.Layer):
# directly connect to an input path and process it
self.preprocessors = dict()
# resample an input and merge it with the output of another path
# in order to aggregate backbone outputs
self.resamples = dict()
# set of convolution layers and upsample layers that are used to
# prepare the FPN processors for output
......@@ -181,7 +172,7 @@ class YoloFPN(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloPAN(tf.keras.layers.Layer):
"""YOLO Path Aggregation Network"""
"""YOLO Path Aggregation Network."""
def __init__(self,
path_process_len=6,
......@@ -216,7 +207,7 @@ class YoloPAN(tf.keras.layers.Layer):
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
fpn_input: `bool`, for whether the input into this function is an FPN or
a backbone.
fpn_filter_scale: `int`, scaling factor for the FPN filters.
**kwargs: keyword arguments to be passed.
......@@ -253,8 +244,7 @@ class YoloPAN(tf.keras.layers.Layer):
norm_momentum=self._norm_momentum)
def build(self, inputs):
"""Use config dictionary to generate all important attributes for head
construction.
"""Use config dictionary to generate all important attributes for head.
Args:
inputs: a dictionary mapping each input's name to its shape as a list.
......@@ -270,7 +260,7 @@ class YoloPAN(tf.keras.layers.Layer):
# directly connect to an input path and process it
self.preprocessors = dict()
# resample an input and merge it with the output of another path
# in order to aggregate backbone outputs
self.resamples = dict()
# FPN will reverse the key process order for the backbone, so we need
......@@ -368,7 +358,7 @@ class YoloPAN(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloDecoder(tf.keras.Model):
"""Darknet Backbone Decoder"""
"""Darknet Backbone Decoder."""
def __init__(self,
input_specs,
......@@ -388,8 +378,10 @@ class YoloDecoder(tf.keras.Model):
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
"""Yolo Decoder initialization function. A unified model that ties all
decoder components into a conditionally build YOLO decoder.
"""Yolo Decoder initialization function.
A unified model that ties all decoder components into a conditionally built
YOLO decoder.
Args:
input_specs: `dict[str, tf.InputSpec]`: input specs of each of the inputs
......@@ -483,4 +475,4 @@ class YoloDecoder(tf.keras.Model):
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
return cls(**config)
\ No newline at end of file
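For orientation, a sketch instantiating the decoder with only the constructor arguments exercised by the tests further down (values illustrative):

    input_specs = {
        '3': [1, 52, 52, 256],
        '4': [1, 26, 26, 512],
        '5': [1, 13, 13, 1024],
    }
    decoder = YoloDecoder(
        input_specs=input_specs,
        embed_spp=False,            # optionally add spatial pyramid pooling
        use_fpn=True,               # run YoloFPN before the PAN stage
        max_level_process_len=None,
        path_process_len=6,
        activation='mish')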
......@@ -17,7 +17,6 @@
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
......@@ -27,6 +26,45 @@ from official.vision.beta.projects.yolo.modeling.decoders import yolo_decoder as
class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
def _build_yolo_decoder(self, input_specs, name='1'):
# Builds one of four arbitrary test decoders, selected by name.
if name == '1':
model = decoders.YoloDecoder(
input_specs=input_specs,
embed_spp=False,
use_fpn=False,
max_level_process_len=2,
path_process_len=1,
activation='mish')
elif name == '6spp':
model = decoders.YoloDecoder(
input_specs=input_specs,
embed_spp=True,
use_fpn=False,
max_level_process_len=None,
path_process_len=6,
activation='mish')
elif name == '6sppfpn':
model = decoders.YoloDecoder(
input_specs=input_specs,
embed_spp=True,
use_fpn=True,
max_level_process_len=None,
path_process_len=6,
activation='mish')
elif name == '6':
model = decoders.YoloDecoder(
input_specs=input_specs,
embed_spp=False,
use_fpn=False,
max_level_process_len=None,
path_process_len=6,
activation='mish')
else:
raise NotImplementedError(f'YOLO decoder test {name} not implemented.')
return model
@parameterized.parameters('1', '6spp', '6sppfpn', '6')
def test_network_creation(self, version):
"""Test creation of ResNet family models."""
......@@ -36,10 +74,10 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
'4': [1, 26, 26, 512],
'5': [1, 13, 13, 1024]
}
decoder = build_yolo_decoder(input_shape, version)
decoder = self._build_yolo_decoder(input_shape, version)
inputs = {}
for key in input_shape.keys():
for key in input_shape:
inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
endpoints = decoder.call(inputs)
......@@ -50,7 +88,7 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.tpu_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
use_sync_bn=[False, True],
......@@ -66,10 +104,10 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
'4': [1, 26, 26, 512],
'5': [1, 13, 13, 1024]
}
decoder = build_yolo_decoder(input_shape, '6')
decoder = self._build_yolo_decoder(input_shape, '6')
inputs = {}
for key in input_shape.keys():
for key in input_shape:
inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
_ = decoder.call(inputs)
......@@ -84,10 +122,10 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
'4': [1, 26, 26, 512],
'5': [1, 13, 13, 1024]
}
decoder = build_yolo_decoder(input_shape, '6')
decoder = self._build_yolo_decoder(input_shape, '6')
inputs = {}
for key in input_shape.keys():
for key in input_shape:
inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
_ = decoder(inputs)
......@@ -100,10 +138,10 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
'4': [1, 26, 26, 512],
'5': [1, 13, 13, 1024]
}
decoder = build_yolo_decoder(input_shape, '6')
decoder = self._build_yolo_decoder(input_shape, '6')
inputs = {}
for key in input_shape.keys():
for key in input_shape:
inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
_ = decoder(inputs)
......@@ -111,44 +149,5 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
decoder_from_config = decoders.YoloDecoder.from_config(config)
self.assertAllEqual(decoder.get_config(), decoder_from_config.get_config())
def build_yolo_decoder(input_specs, type='1'):
if type == '1':
model = decoders.YoloDecoder(
input_specs=input_specs,
embed_spp=False,
use_fpn=False,
max_level_process_len=2,
path_process_len=1,
activation='mish')
elif type == '6spp':
model = decoders.YoloDecoder(
input_specs=input_specs,
embed_spp=True,
use_fpn=False,
max_level_process_len=None,
path_process_len=6,
activation='mish')
elif type == '6sppfpn':
model = decoders.YoloDecoder(
input_specs=input_specs,
embed_spp=True,
use_fpn=True,
max_level_process_len=None,
path_process_len=6,
activation='mish')
elif type == '6':
model = decoders.YoloDecoder(
input_specs=input_specs,
embed_spp=False,
use_fpn=False,
max_level_process_len=None,
path_process_len=6,
activation='mish')
else:
raise NotImplementedError(f"YOLO decoder test {type} not implemented.")
return model
if __name__ == '__main__':
tf.test.main()
tf.test.main()
\ No newline at end of file
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -13,12 +13,14 @@
# limitations under the License.
# Lint as: python3
"""Yolo heads."""
import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.layers import nn_blocks
class YoloHead(tf.keras.layers.Layer):
"""YOLO Prediction Head"""
"""YOLO Prediction Head."""
def __init__(self,
min_level,
......@@ -117,4 +119,4 @@ class YoloHead(tf.keras.layers.Layer):
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
return cls(**config)
\ No newline at end of file
......@@ -13,15 +13,12 @@
# limitations under the License.
# Lint as: python3
"""Tests for YOLO heads."""
"""Tests for yolo heads."""
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.vision.beta.projects.yolo.modeling.heads import yolo_head as heads
......@@ -40,10 +37,11 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
head = heads.YoloHead(3, 5, classes=classes, boxes_per_level=bps)
inputs = {}
for key in input_shape.keys():
for key in input_shape:
inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
endpoints = head(inputs)
for key in endpoints.keys():
expected_input_shape = input_shape[key]
......@@ -63,7 +61,7 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
head = heads.YoloHead(3, 5, classes=classes, boxes_per_level=bps)
inputs = {}
for key in input_shape.keys():
for key in input_shape:
inputs[key] = tf.ones(input_shape[key], dtype=tf.float32)
_ = head(inputs)
......@@ -71,6 +69,5 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
head_from_config = heads.YoloHead.from_config(configs)
self.assertAllEqual(head.get_config(), head_from_config.get_config())
if __name__ == '__main__':
tf.test.main()
tf.test.main()
\ No newline at end of file
......@@ -13,6 +13,8 @@
# limitations under the License.
# Lint as: python3
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import numpy as np
from absl.testing import parameterized
......@@ -334,12 +336,13 @@ class DarkRouteProcessTest(tf.test.TestCase, parameterized.TestCase):
test_layer = nn_blocks.DarkRouteProcess(
filters=filters, repetitions=repetitions, insert_spp=spp)
outx = test_layer(x)
self.assertEqual(len(outx), 2, msg='len(outx) != 2')
self.assertLen(outx, 2, msg='len(outx) != 2')
if repetitions == 1:
filter_y1 = filters
else:
filter_y1 = filters // 2
self.assertAllEqual(outx[1].shape.as_list(), [None, width, height, filter_y1])
self.assertAllEqual(
outx[1].shape.as_list(), [None, width, height, filter_y1])
self.assertAllEqual(
filters % 2,
0,
......@@ -366,7 +369,8 @@ class DarkRouteProcessTest(tf.test.TestCase, parameterized.TestCase):
y_0 = tf.Variable(
initial_value=init(shape=(1, width, height, filters), dtype=tf.float32))
y_1 = tf.Variable(
initial_value=init(shape=(1, width, height, filter_y1), dtype=tf.float32))
initial_value=init(
shape=(1, width, height, filter_y1), dtype=tf.float32))
with tf.GradientTape() as tape:
x_hat_0, x_hat_1 = test_layer(x)
......@@ -379,6 +383,5 @@ class DarkRouteProcessTest(tf.test.TestCase, parameterized.TestCase):
self.assertNotIn(None, grad)
return
if __name__ == '__main__':
tf.test.main()
tf.test.main()
\ No newline at end of file
......@@ -34,9 +34,9 @@ class DetectionModule(export_base.ExportModule):
def _build_model(self):
if self._batch_size is None:
ValueError("batch_size can't be None for detection models")
raise ValueError('batch_size cannot be None for detection models.')
if not self.params.task.model.detection_generator.use_batched_nms:
ValueError('Only batched_nms is supported.')
raise ValueError('Only batched_nms is supported.')
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3])
......
......@@ -118,6 +118,20 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(outputs['num_detections'].numpy(),
expected_outputs['num_detections'].numpy())
def test_build_model_fail_with_none_batch_size(self):
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
with self.assertRaisesRegex(
ValueError, 'batch_size cannot be None for detection models.'):
detection.DetectionModule(
params, batch_size=None, input_image_size=[640, 640])
def test_build_model_fail_with_batched_nms_false(self):
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
params.task.model.detection_generator.use_batched_nms = False
with self.assertRaisesRegex(ValueError, 'Only batched_nms is supported.'):
detection.DetectionModule(
params, batch_size=1, input_image_size=[640, 640])
if __name__ == '__main__':
tf.test.main()
......@@ -104,6 +104,7 @@ class ImageClassificationTask(base_task.Task):
num_classes=num_classes,
image_field_key=image_field_key,
label_field_key=label_field_key,
decode_jpeg_only=params.decode_jpeg_only,
aug_rand_hflip=params.aug_rand_hflip,
aug_type=params.aug_type,
is_multilabel=is_multilabel,
......
......@@ -133,12 +133,54 @@ class RetinaNetTask(base_task.Task):
return dataset
def build_attribute_loss(self,
attribute_heads: List[exp_cfg.AttributeHead],
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
box_sample_weight: tf.Tensor) -> float:
"""Computes attribute loss.
Args:
attribute_heads: a list of attribute head configs.
outputs: RetinaNet model outputs.
labels: RetinaNet labels.
box_sample_weight: normalized bounding box sample weights.
Returns:
Attribute loss of all attribute heads.
"""
attribute_loss = 0.0
for head in attribute_heads:
if head.name not in labels['attribute_targets']:
raise ValueError(f'Attribute {head.name} not found in label targets.')
if head.name not in outputs['attribute_outputs']:
raise ValueError(f'Attribute {head.name} not found in model outputs.')
y_true_att = keras_cv.losses.multi_level_flatten(
labels['attribute_targets'][head.name], last_dim=head.size)
y_pred_att = keras_cv.losses.multi_level_flatten(
outputs['attribute_outputs'][head.name], last_dim=head.size)
if head.type == 'regression':
att_loss_fn = tf.keras.losses.Huber(
1.0, reduction=tf.keras.losses.Reduction.SUM)
att_loss = att_loss_fn(
y_true=y_true_att,
y_pred=y_pred_att,
sample_weight=box_sample_weight)
else:
raise ValueError(f'Attribute type {head.type} not supported.')
attribute_loss += att_loss
return attribute_loss
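To show what this loss expects from the config, a hypothetical head declaration; only the fields the method reads (`name`, `type`, `size`) are grounded in the code above, and `task`, `outputs`, `labels`, and `box_sample_weight` are placeholders:

    depth_head = exp_cfg.AttributeHead(name='depth', type='regression', size=1)
    attribute_loss = task.build_attribute_loss(
        [depth_head], outputs, labels, box_sample_weight)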
def build_losses(self,
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
aux_losses: Optional[Any] = None):
"""Build RetinaNet losses."""
params = self.task_config
attribute_heads = self.task_config.model.head.attribute_heads
cls_loss_fn = keras_cv.losses.FocalLoss(
alpha=params.losses.focal_loss_alpha,
gamma=params.losses.focal_loss_gamma,
......@@ -170,6 +212,10 @@ class RetinaNetTask(base_task.Task):
model_loss = cls_loss + params.losses.box_loss_weight * box_loss
if attribute_heads:
model_loss += self.build_attribute_loss(attribute_heads, outputs, labels,
box_sample_weight)
total_loss = model_loss
if aux_losses:
reg_loss = tf.reduce_sum(aux_losses)
......
......@@ -322,21 +322,21 @@ class DistributedExecutor(object):
return test_step
def train(self,
train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
eval_input_fn: Callable[[params_dict.ParamsDict],
tf.data.Dataset] = None,
model_dir: Text = None,
total_steps: int = 1,
iterations_per_loop: int = 1,
train_metric_fn: Callable[[], Any] = None,
eval_metric_fn: Callable[[], Any] = None,
summary_writer_fn: Callable[[Text, Text],
SummaryWriter] = SummaryWriter,
init_checkpoint: Callable[[tf.keras.Model], Any] = None,
custom_callbacks: List[tf.keras.callbacks.Callback] = None,
continuous_eval: bool = False,
save_config: bool = True):
def train(
self,
train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
eval_input_fn: Optional[Callable[[params_dict.ParamsDict],
tf.data.Dataset]] = None,
model_dir: Optional[Text] = None,
total_steps: int = 1,
iterations_per_loop: int = 1,
train_metric_fn: Optional[Callable[[], Any]] = None,
eval_metric_fn: Optional[Callable[[], Any]] = None,
summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter,
init_checkpoint: Optional[Callable[[tf.keras.Model], Any]] = None,
custom_callbacks: Optional[List[tf.keras.callbacks.Callback]] = None,
continuous_eval: bool = False,
save_config: bool = True):
"""Runs distributed training.
Args:
......@@ -590,7 +590,7 @@ class DistributedExecutor(object):
eval_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
eval_metric_fn: Callable[[], Any],
total_steps: int = -1,
eval_timeout: int = None,
eval_timeout: Optional[int] = None,
min_eval_interval: int = 180,
summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter):
"""Runs distributed evaluation on model folder.
......@@ -646,7 +646,7 @@ class DistributedExecutor(object):
eval_input_fn: Callable[[params_dict.ParamsDict],
tf.data.Dataset],
eval_metric_fn: Callable[[], Any],
summary_writer: SummaryWriter = None):
summary_writer: Optional[SummaryWriter] = None):
"""Runs distributed evaluation on the one checkpoint.
Args:
......
......@@ -20,7 +20,7 @@ from __future__ import division
from __future__ import print_function
import os
from typing import Any, List, MutableMapping, Text
from typing import Any, List, MutableMapping, Optional, Text
from absl import logging
import tensorflow as tf
......@@ -39,7 +39,7 @@ def get_callbacks(
initial_step: int = 0,
batch_size: int = 0,
log_steps: int = 0,
model_dir: str = None,
model_dir: Optional[str] = None,
backup_and_restore: bool = False) -> List[tf.keras.callbacks.Callback]:
"""Get all callbacks."""
model_dir = model_dir or ''
......@@ -120,7 +120,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def on_batch_begin(self,
epoch: int,
logs: MutableMapping[str, Any] = None) -> None:
logs: Optional[MutableMapping[str, Any]] = None) -> None:
self.step += 1
if logs is None:
logs = {}
......@@ -129,7 +129,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def on_epoch_begin(self,
epoch: int,
logs: MutableMapping[str, Any] = None) -> None:
logs: Optional[MutableMapping[str, Any]] = None) -> None:
if logs is None:
logs = {}
metrics = self._calculate_metrics()
......@@ -140,7 +140,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def on_epoch_end(self,
epoch: int,
logs: MutableMapping[str, Any] = None) -> None:
logs: Optional[MutableMapping[str, Any]] = None) -> None:
if logs is None:
logs = {}
metrics = self._calculate_metrics()
......@@ -195,13 +195,13 @@ class MovingAverageCallback(tf.keras.callbacks.Callback):
optimization.ExponentialMovingAverage)
self.model.optimizer.shadow_copy(self.model)
def on_test_begin(self, logs: MutableMapping[Text, Any] = None):
def on_test_begin(self, logs: Optional[MutableMapping[Text, Any]] = None):
self.model.optimizer.swap_weights()
def on_test_end(self, logs: MutableMapping[Text, Any] = None):
def on_test_end(self, logs: Optional[MutableMapping[Text, Any]] = None):
self.model.optimizer.swap_weights()
def on_train_end(self, logs: MutableMapping[Text, Any] = None):
def on_train_end(self, logs: Optional[MutableMapping[Text, Any]] = None):
if self.overwrite_weights_on_train_end:
self.model.optimizer.assign_average_vars(self.model.variables)
......
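Assumed wiring for the callback above: the model's optimizer must be an `optimization.ExponentialMovingAverage`, which the callback itself asserts; constructor arguments beyond the no-arg form are not shown in this diff, and `ema_optimizer`, `loss_fn`, and `train_ds` are placeholders:

    model.compile(optimizer=ema_optimizer, loss=loss_fn)
    model.fit(train_ds, epochs=10, callbacks=[MovingAverageCallback()])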
......@@ -280,7 +280,9 @@ class DatasetBuilder:
raise e
return self.builder_info
def build(self, strategy: tf.distribute.Strategy = None) -> tf.data.Dataset:
def build(
self,
strategy: Optional[tf.distribute.Strategy] = None) -> tf.data.Dataset:
"""Construct a dataset end-to-end and return it using an optional strategy.
Args:
......@@ -305,7 +307,8 @@ class DatasetBuilder:
def _build(
self,
input_context: tf.distribute.InputContext = None) -> tf.data.Dataset:
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Construct a dataset end-to-end and return it.
Args:
......
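Hypothetical usage of the updated signatures (construction of the `builder` instance elided):

    strategy = tf.distribute.MirroredStrategy()
    dataset = builder.build(strategy=strategy)  # distribute via the strategy
    dataset_local = builder.build()             # strategy defaults to None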
......@@ -160,9 +160,9 @@ def conv2d_block(inputs: tf.Tensor,
strides: Any = (1, 1),
use_batch_norm: bool = True,
use_bias: bool = False,
activation: Any = None,
activation: Optional[Any] = None,
depthwise: bool = False,
name: Text = None):
name: Optional[Text] = None):
"""A conv2d followed by batch norm and an activation."""
batch_norm = common_modules.get_batch_norm(config.batch_norm)
bn_momentum = config.bn_momentum
......@@ -212,7 +212,7 @@ def conv2d_block(inputs: tf.Tensor,
def mb_conv_block(inputs: tf.Tensor,
block: BlockConfig,
config: ModelConfig,
prefix: Text = None):
prefix: Optional[Text] = None):
"""Mobile Inverted Residual Bottleneck.
Args:
......@@ -432,8 +432,8 @@ class EfficientNet(tf.keras.Model):
"""
def __init__(self,
config: ModelConfig = None,
overrides: Dict[Text, Any] = None):
config: Optional[ModelConfig] = None,
overrides: Optional[Dict[Text, Any]] = None):
"""Create an EfficientNet model.
Args:
......@@ -463,9 +463,9 @@ class EfficientNet(tf.keras.Model):
@classmethod
def from_name(cls,
model_name: Text,
model_weights_path: Text = None,
model_weights_path: Optional[Text] = None,
weights_format: Text = 'saved_model',
overrides: Dict[Text, Any] = None):
overrides: Optional[Dict[Text, Any]] = None):
"""Construct an EfficientNet model from a predefined model name.
E.g., `EfficientNet.from_name('efficientnet-b0')`.
......
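Expanding the docstring's own example into a sketch; the `overrides` key shown is a hypothetical guess, not confirmed by this diff:

    model = EfficientNet.from_name(
        'efficientnet-b0',
        model_weights_path=None,             # or a path to pretrained weights
        weights_format='saved_model',
        overrides={'rescale_input': False})  # hypothetical override key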
......@@ -18,7 +18,7 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from typing import Any, Dict, Text
from typing import Any, Dict, Optional, Text
from absl import logging
import tensorflow as tf
......@@ -35,7 +35,7 @@ def build_optimizer(
optimizer_name: Text,
base_learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule,
params: Dict[Text, Any],
model: tf.keras.Model = None):
model: Optional[tf.keras.Model] = None):
"""Build the optimizer based on name.
Args:
......@@ -124,9 +124,9 @@ def build_optimizer(
def build_learning_rate(params: base_configs.LearningRateConfig,
batch_size: int = None,
train_epochs: int = None,
train_steps: int = None):
batch_size: Optional[int] = None,
train_epochs: Optional[int] = None,
train_steps: Optional[int] = None):
"""Build the learning rate given the provided configuration."""
decay_type = params.name
base_lr = params.initial_lr
......
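A hedged sketch tying the two builders above together; `lr_config` is an assumed `base_configs.LearningRateConfig` instance, and the optimizer params dict keys are illustrative:

    lr = build_learning_rate(
        lr_config, batch_size=128, train_epochs=90, train_steps=None)
    optimizer = build_optimizer(
        'momentum', base_learning_rate=lr,
        params={'momentum': 0.9, 'nesterov': True},  # illustrative keys
        model=None)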