Commit 790e49e5 authored by stephenwu's avatar stephenwu
Browse files

Merge branch 'master' of https://github.com/tensorflow/models into run_superglue

parents 8ab018b0 5bb827c3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""The COCO-style evaluator. """The COCO-style evaluator.
The following snippet demonstrates the use of interfaces: The following snippet demonstrates the use of interfaces:
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Util functions related to pycocotools and COCO eval.""" """Util functions related to pycocotools and COCO eval."""
import copy import copy
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Metrics for segmentation.""" """Metrics for segmentation."""
import tensorflow as tf import tensorflow as tf
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Losses for maskrcn model.""" """Losses for maskrcn model."""
# Import libraries # Import libraries
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Losses used for detection models.""" """Losses used for detection models."""
# Import libraries # Import libraries
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Losses used for segmentation models.""" """Losses used for segmentation models."""
# Import libraries # Import libraries
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,8 +11,11 @@ ...@@ -12,8 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
# Lint as: python3
"""Modeling package definition.""" """Modeling package definition."""
from official.vision.beta.modeling import backbones from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import decoders from official.vision.beta.modeling import decoders
from official.vision.beta.modeling import heads
from official.vision.beta.modeling import layers
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,8 @@ ...@@ -12,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
# Lint as: python3
"""Backbones package definition.""" """Backbones package definition."""
from official.vision.beta.modeling.backbones.efficientnet import EfficientNet from official.vision.beta.modeling.backbones.efficientnet import EfficientNet
...@@ -22,3 +22,4 @@ from official.vision.beta.modeling.backbones.resnet_3d import ResNet3D ...@@ -22,3 +22,4 @@ from official.vision.beta.modeling.backbones.resnet_3d import ResNet3D
from official.vision.beta.modeling.backbones.resnet_deeplab import DilatedResNet from official.vision.beta.modeling.backbones.resnet_deeplab import DilatedResNet
from official.vision.beta.modeling.backbones.revnet import RevNet from official.vision.beta.modeling.backbones.revnet import RevNet
from official.vision.beta.modeling.backbones.spinenet import SpineNet from official.vision.beta.modeling.backbones.spinenet import SpineNet
from official.vision.beta.modeling.backbones.spinenet_mobile import SpineNetMobile
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Contains definitions of EfficientNet Networks.""" """Contains definitions of EfficientNet Networks."""
import math import math
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,8 @@ ...@@ -12,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
# Lint as: python3
"""Tests for EfficientNet.""" """Tests for EfficientNet."""
# Import libraries # Import libraries
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,8 @@ ...@@ -12,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
# Lint as: python3
"""Backbone registers and factory method. """Backbone registers and factory method.
One can regitered a new backbone model by the following two steps: One can regitered a new backbone model by the following two steps:
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,8 @@ ...@@ -12,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
# Lint as: python3
"""Tests for factory functions.""" """Tests for factory functions."""
# Import libraries # Import libraries
from absl.testing import parameterized from absl.testing import parameterized
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Contains definitions of MobileNet Networks.""" """Contains definitions of MobileNet Networks."""
from typing import Optional, Dict, Any, Tuple from typing import Optional, Dict, Any, Tuple
...@@ -148,22 +148,23 @@ Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam ...@@ -148,22 +148,23 @@ Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam
""" """
MNV1_BLOCK_SPECS = { MNV1_BLOCK_SPECS = {
'spec_name': 'MobileNetV1', 'spec_name': 'MobileNetV1',
'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters'], 'block_spec_schema': ['block_fn', 'kernel_size', 'strides',
'filters', 'is_output'],
'block_specs': [ 'block_specs': [
('convbn', 3, 2, 32), ('convbn', 3, 2, 32, False),
('depsepconv', 3, 1, 64), ('depsepconv', 3, 1, 64, False),
('depsepconv', 3, 2, 128), ('depsepconv', 3, 2, 128, False),
('depsepconv', 3, 1, 128), ('depsepconv', 3, 1, 128, True),
('depsepconv', 3, 2, 256), ('depsepconv', 3, 2, 256, False),
('depsepconv', 3, 1, 256), ('depsepconv', 3, 1, 256, True),
('depsepconv', 3, 2, 512), ('depsepconv', 3, 2, 512, False),
('depsepconv', 3, 1, 512), ('depsepconv', 3, 1, 512, False),
('depsepconv', 3, 1, 512), ('depsepconv', 3, 1, 512, False),
('depsepconv', 3, 1, 512), ('depsepconv', 3, 1, 512, False),
('depsepconv', 3, 1, 512), ('depsepconv', 3, 1, 512, False),
('depsepconv', 3, 1, 512), ('depsepconv', 3, 1, 512, True),
('depsepconv', 3, 2, 1024), ('depsepconv', 3, 2, 1024, False),
('depsepconv', 3, 1, 1024), ('depsepconv', 3, 1, 1024, True),
] ]
} }
...@@ -176,27 +177,27 @@ Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen ...@@ -176,27 +177,27 @@ Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
MNV2_BLOCK_SPECS = { MNV2_BLOCK_SPECS = {
'spec_name': 'MobileNetV2', 'spec_name': 'MobileNetV2',
'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters', 'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
'expand_ratio'], 'expand_ratio', 'is_output'],
'block_specs': [ 'block_specs': [
('convbn', 3, 2, 32, None), ('convbn', 3, 2, 32, None, False),
('invertedbottleneck', 3, 1, 16, 1.), ('invertedbottleneck', 3, 1, 16, 1., False),
('invertedbottleneck', 3, 2, 24, 6.), ('invertedbottleneck', 3, 2, 24, 6., False),
('invertedbottleneck', 3, 1, 24, 6.), ('invertedbottleneck', 3, 1, 24, 6., True),
('invertedbottleneck', 3, 2, 32, 6.), ('invertedbottleneck', 3, 2, 32, 6., False),
('invertedbottleneck', 3, 1, 32, 6.), ('invertedbottleneck', 3, 1, 32, 6., False),
('invertedbottleneck', 3, 1, 32, 6.), ('invertedbottleneck', 3, 1, 32, 6., True),
('invertedbottleneck', 3, 2, 64, 6.), ('invertedbottleneck', 3, 2, 64, 6., False),
('invertedbottleneck', 3, 1, 64, 6.), ('invertedbottleneck', 3, 1, 64, 6., False),
('invertedbottleneck', 3, 1, 64, 6.), ('invertedbottleneck', 3, 1, 64, 6., False),
('invertedbottleneck', 3, 1, 64, 6.), ('invertedbottleneck', 3, 1, 64, 6., False),
('invertedbottleneck', 3, 1, 96, 6.), ('invertedbottleneck', 3, 1, 96, 6., False),
('invertedbottleneck', 3, 1, 96, 6.), ('invertedbottleneck', 3, 1, 96, 6., False),
('invertedbottleneck', 3, 1, 96, 6.), ('invertedbottleneck', 3, 1, 96, 6., True),
('invertedbottleneck', 3, 2, 160, 6.), ('invertedbottleneck', 3, 2, 160, 6., False),
('invertedbottleneck', 3, 1, 160, 6.), ('invertedbottleneck', 3, 1, 160, 6., False),
('invertedbottleneck', 3, 1, 160, 6.), ('invertedbottleneck', 3, 1, 160, 6., False),
('invertedbottleneck', 3, 1, 320, 6.), ('invertedbottleneck', 3, 1, 320, 6., True),
('convbn', 1, 1, 1280, None), ('convbn', 1, 1, 1280, None, False),
] ]
} }
...@@ -211,27 +212,46 @@ MNV3Large_BLOCK_SPECS = { ...@@ -211,27 +212,46 @@ MNV3Large_BLOCK_SPECS = {
'spec_name': 'MobileNetV3Large', 'spec_name': 'MobileNetV3Large',
'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters', 'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
'activation', 'se_ratio', 'expand_ratio', 'activation', 'se_ratio', 'expand_ratio',
'use_normalization', 'use_bias'], 'use_normalization', 'use_bias', 'is_output'],
'block_specs': [ 'block_specs': [
('convbn', 3, 2, 16, 'hard_swish', None, None, True, False), ('convbn', 3, 2, 16,
('invertedbottleneck', 3, 1, 16, 'relu', None, 1., None, False), 'hard_swish', None, None, True, False, False),
('invertedbottleneck', 3, 2, 24, 'relu', None, 4., None, False), ('invertedbottleneck', 3, 1, 16,
('invertedbottleneck', 3, 1, 24, 'relu', None, 3., None, False), 'relu', None, 1., None, False, False),
('invertedbottleneck', 5, 2, 40, 'relu', 0.25, 3., None, False), ('invertedbottleneck', 3, 2, 24,
('invertedbottleneck', 5, 1, 40, 'relu', 0.25, 3., None, False), 'relu', None, 4., None, False, False),
('invertedbottleneck', 5, 1, 40, 'relu', 0.25, 3., None, False), ('invertedbottleneck', 3, 1, 24,
('invertedbottleneck', 3, 2, 80, 'hard_swish', None, 6., None, False), 'relu', None, 3., None, False, True),
('invertedbottleneck', 3, 1, 80, 'hard_swish', None, 2.5, None, False), ('invertedbottleneck', 5, 2, 40,
('invertedbottleneck', 3, 1, 80, 'hard_swish', None, 2.3, None, False), 'relu', 0.25, 3., None, False, False),
('invertedbottleneck', 3, 1, 80, 'hard_swish', None, 2.3, None, False), ('invertedbottleneck', 5, 1, 40,
('invertedbottleneck', 3, 1, 112, 'hard_swish', 0.25, 6., None, False), 'relu', 0.25, 3., None, False, False),
('invertedbottleneck', 3, 1, 112, 'hard_swish', 0.25, 6., None, False), ('invertedbottleneck', 5, 1, 40,
('invertedbottleneck', 5, 2, 160, 'hard_swish', 0.25, 6., None, False), 'relu', 0.25, 3., None, False, True),
('invertedbottleneck', 5, 1, 160, 'hard_swish', 0.25, 6., None, False), ('invertedbottleneck', 3, 2, 80,
('invertedbottleneck', 5, 1, 160, 'hard_swish', 0.25, 6., None, False), 'hard_swish', None, 6., None, False, False),
('convbn', 1, 1, 960, 'hard_swish', None, None, True, False), ('invertedbottleneck', 3, 1, 80,
('gpooling', None, None, None, None, None, None, None, None), 'hard_swish', None, 2.5, None, False, False),
('convbn', 1, 1, 1280, 'hard_swish', None, None, False, True), ('invertedbottleneck', 3, 1, 80,
'hard_swish', None, 2.3, None, False, False),
('invertedbottleneck', 3, 1, 80,
'hard_swish', None, 2.3, None, False, False),
('invertedbottleneck', 3, 1, 112,
'hard_swish', 0.25, 6., None, False, False),
('invertedbottleneck', 3, 1, 112,
'hard_swish', 0.25, 6., None, False, True),
('invertedbottleneck', 5, 2, 160,
'hard_swish', 0.25, 6., None, False, False),
('invertedbottleneck', 5, 1, 160,
'hard_swish', 0.25, 6., None, False, False),
('invertedbottleneck', 5, 1, 160,
'hard_swish', 0.25, 6., None, False, True),
('convbn', 1, 1, 960,
'hard_swish', None, None, True, False, False),
('gpooling', None, None, None,
None, None, None, None, None, False),
('convbn', 1, 1, 1280,
'hard_swish', None, None, False, True, False),
] ]
} }
...@@ -239,23 +259,38 @@ MNV3Small_BLOCK_SPECS = { ...@@ -239,23 +259,38 @@ MNV3Small_BLOCK_SPECS = {
'spec_name': 'MobileNetV3Small', 'spec_name': 'MobileNetV3Small',
'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters', 'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
'activation', 'se_ratio', 'expand_ratio', 'activation', 'se_ratio', 'expand_ratio',
'use_normalization', 'use_bias'], 'use_normalization', 'use_bias', 'is_output'],
'block_specs': [ 'block_specs': [
('convbn', 3, 2, 16, 'hard_swish', None, None, True, False), ('convbn', 3, 2, 16,
('invertedbottleneck', 3, 2, 16, 'relu', 0.25, 1, None, False), 'hard_swish', None, None, True, False, False),
('invertedbottleneck', 3, 2, 24, 'relu', None, 72. / 16, None, False), ('invertedbottleneck', 3, 2, 16,
('invertedbottleneck', 3, 1, 24, 'relu', None, 88. / 24, None, False), 'relu', 0.25, 1, None, False, True),
('invertedbottleneck', 5, 2, 40, 'hard_swish', 0.25, 4., None, False), ('invertedbottleneck', 3, 2, 24,
('invertedbottleneck', 5, 1, 40, 'hard_swish', 0.25, 6., None, False), 'relu', None, 72. / 16, None, False, False),
('invertedbottleneck', 5, 1, 40, 'hard_swish', 0.25, 6., None, False), ('invertedbottleneck', 3, 1, 24,
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 3., None, False), 'relu', None, 88. / 24, None, False, True),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 3., None, False), ('invertedbottleneck', 5, 2, 40,
('invertedbottleneck', 5, 2, 96, 'hard_swish', 0.25, 6., None, False), 'hard_swish', 0.25, 4., None, False, False),
('invertedbottleneck', 5, 1, 96, 'hard_swish', 0.25, 6., None, False), ('invertedbottleneck', 5, 1, 40,
('invertedbottleneck', 5, 1, 96, 'hard_swish', 0.25, 6., None, False), 'hard_swish', 0.25, 6., None, False, False),
('convbn', 1, 1, 576, 'hard_swish', None, None, True, False), ('invertedbottleneck', 5, 1, 40,
('gpooling', None, None, None, None, None, None, None, None), 'hard_swish', 0.25, 6., None, False, False),
('convbn', 1, 1, 1024, 'hard_swish', None, None, False, True), ('invertedbottleneck', 5, 1, 48,
'hard_swish', 0.25, 3., None, False, False),
('invertedbottleneck', 5, 1, 48,
'hard_swish', 0.25, 3., None, False, True),
('invertedbottleneck', 5, 2, 96,
'hard_swish', 0.25, 6., None, False, False),
('invertedbottleneck', 5, 1, 96,
'hard_swish', 0.25, 6., None, False, False),
('invertedbottleneck', 5, 1, 96,
'hard_swish', 0.25, 6., None, False, True),
('convbn', 1, 1, 576,
'hard_swish', None, None, True, False, False),
('gpooling', None, None, None,
None, None, None, None, None, False),
('convbn', 1, 1, 1024,
'hard_swish', None, None, False, True, False),
] ]
} }
...@@ -267,32 +302,32 @@ MNV3EdgeTPU_BLOCK_SPECS = { ...@@ -267,32 +302,32 @@ MNV3EdgeTPU_BLOCK_SPECS = {
'spec_name': 'MobileNetV3EdgeTPU', 'spec_name': 'MobileNetV3EdgeTPU',
'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters', 'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
'activation', 'se_ratio', 'expand_ratio', 'activation', 'se_ratio', 'expand_ratio',
'use_residual', 'use_depthwise'], 'use_residual', 'use_depthwise', 'is_output'],
'block_specs': [ 'block_specs': [
('convbn', 3, 2, 32, 'relu', None, None, None, None), ('convbn', 3, 2, 32, 'relu', None, None, None, None, False),
('invertedbottleneck', 3, 1, 16, 'relu', None, 1., True, False), ('invertedbottleneck', 3, 1, 16, 'relu', None, 1., True, False, False),
('invertedbottleneck', 3, 2, 32, 'relu', None, 8., True, False), ('invertedbottleneck', 3, 2, 32, 'relu', None, 8., True, False, False),
('invertedbottleneck', 3, 1, 32, 'relu', None, 4., True, False), ('invertedbottleneck', 3, 1, 32, 'relu', None, 4., True, False, False),
('invertedbottleneck', 3, 1, 32, 'relu', None, 4., True, False), ('invertedbottleneck', 3, 1, 32, 'relu', None, 4., True, False, False),
('invertedbottleneck', 3, 1, 32, 'relu', None, 4., True, False), ('invertedbottleneck', 3, 1, 32, 'relu', None, 4., True, False, True),
('invertedbottleneck', 3, 2, 48, 'relu', None, 8., True, False), ('invertedbottleneck', 3, 2, 48, 'relu', None, 8., True, False, False),
('invertedbottleneck', 3, 1, 48, 'relu', None, 4., True, False), ('invertedbottleneck', 3, 1, 48, 'relu', None, 4., True, False, False),
('invertedbottleneck', 3, 1, 48, 'relu', None, 4., True, False), ('invertedbottleneck', 3, 1, 48, 'relu', None, 4., True, False, False),
('invertedbottleneck', 3, 1, 48, 'relu', None, 4., True, False), ('invertedbottleneck', 3, 1, 48, 'relu', None, 4., True, False, True),
('invertedbottleneck', 3, 2, 96, 'relu', None, 8., True, True), ('invertedbottleneck', 3, 2, 96, 'relu', None, 8., True, True, False),
('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True), ('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True, False),
('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True), ('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True, False),
('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True), ('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True, False),
('invertedbottleneck', 3, 1, 96, 'relu', None, 8., False, True), ('invertedbottleneck', 3, 1, 96, 'relu', None, 8., False, True, False),
('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True), ('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True, False),
('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True), ('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True, False),
('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True), ('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True, True),
('invertedbottleneck', 5, 2, 160, 'relu', None, 8., True, True), ('invertedbottleneck', 5, 2, 160, 'relu', None, 8., True, True, False),
('invertedbottleneck', 5, 1, 160, 'relu', None, 4., True, True), ('invertedbottleneck', 5, 1, 160, 'relu', None, 4., True, True, False),
('invertedbottleneck', 5, 1, 160, 'relu', None, 4., True, True), ('invertedbottleneck', 5, 1, 160, 'relu', None, 4., True, True, False),
('invertedbottleneck', 5, 1, 160, 'relu', None, 4., True, True), ('invertedbottleneck', 5, 1, 160, 'relu', None, 4., True, True, False),
('invertedbottleneck', 3, 1, 192, 'relu', None, 8., True, True), ('invertedbottleneck', 3, 1, 192, 'relu', None, 8., True, True, True),
('convbn', 1, 1, 1280, 'relu', None, None, None, None), ('convbn', 1, 1, 1280, 'relu', None, None, None, None, False),
] ]
} }
...@@ -308,26 +343,26 @@ MNMultiMAX_BLOCK_SPECS = { ...@@ -308,26 +343,26 @@ MNMultiMAX_BLOCK_SPECS = {
'spec_name': 'MobileNetMultiMAX', 'spec_name': 'MobileNetMultiMAX',
'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters', 'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
'activation', 'expand_ratio', 'activation', 'expand_ratio',
'use_normalization', 'use_bias'], 'use_normalization', 'use_bias', 'is_output'],
'block_specs': [ 'block_specs': [
('convbn', 3, 2, 32, 'relu', None, True, False), ('convbn', 3, 2, 32, 'relu', None, True, False, False),
('invertedbottleneck', 3, 2, 32, 'relu', 3., None, False), ('invertedbottleneck', 3, 2, 32, 'relu', 3., None, False, True),
('invertedbottleneck', 5, 2, 64, 'relu', 6., None, False), ('invertedbottleneck', 5, 2, 64, 'relu', 6., None, False, False),
('invertedbottleneck', 3, 1, 64, 'relu', 2., None, False), ('invertedbottleneck', 3, 1, 64, 'relu', 2., None, False, False),
('invertedbottleneck', 3, 1, 64, 'relu', 2., None, False), ('invertedbottleneck', 3, 1, 64, 'relu', 2., None, False, True),
('invertedbottleneck', 5, 2, 128, 'relu', 6., None, False), ('invertedbottleneck', 5, 2, 128, 'relu', 6., None, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 4., None, False), ('invertedbottleneck', 3, 1, 128, 'relu', 4., None, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False), ('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False), ('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 6., None, False), ('invertedbottleneck', 3, 1, 128, 'relu', 6., None, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False), ('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False, True),
('invertedbottleneck', 3, 2, 160, 'relu', 6., None, False), ('invertedbottleneck', 3, 2, 160, 'relu', 6., None, False, False),
('invertedbottleneck', 5, 1, 160, 'relu', 4., None, False), ('invertedbottleneck', 5, 1, 160, 'relu', 4., None, False, False),
('invertedbottleneck', 3, 1, 160, 'relu', 5., None, False), ('invertedbottleneck', 3, 1, 160, 'relu', 5., None, False, False),
('invertedbottleneck', 5, 1, 160, 'relu', 4., None, False), ('invertedbottleneck', 5, 1, 160, 'relu', 4., None, False, True),
('convbn', 1, 1, 960, 'relu', None, True, False), ('convbn', 1, 1, 960, 'relu', None, True, False, False),
('gpooling', None, None, None, None, None, None, None), ('gpooling', None, None, None, None, None, None, None, False),
('convbn', 1, 1, 1280, 'relu', None, False, True), ('convbn', 1, 1, 1280, 'relu', None, False, True, False),
] ]
} }
...@@ -335,28 +370,28 @@ MNMultiAVG_BLOCK_SPECS = { ...@@ -335,28 +370,28 @@ MNMultiAVG_BLOCK_SPECS = {
'spec_name': 'MobileNetMultiAVG', 'spec_name': 'MobileNetMultiAVG',
'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters', 'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
'activation', 'expand_ratio', 'activation', 'expand_ratio',
'use_normalization', 'use_bias'], 'use_normalization', 'use_bias', 'is_output'],
'block_specs': [ 'block_specs': [
('convbn', 3, 2, 32, 'relu', None, True, False), ('convbn', 3, 2, 32, 'relu', None, True, False, False),
('invertedbottleneck', 3, 2, 32, 'relu', 3., None, False), ('invertedbottleneck', 3, 2, 32, 'relu', 3., None, False, False),
('invertedbottleneck', 3, 1, 32, 'relu', 2., None, False), ('invertedbottleneck', 3, 1, 32, 'relu', 2., None, False, True),
('invertedbottleneck', 5, 2, 64, 'relu', 5., None, False), ('invertedbottleneck', 5, 2, 64, 'relu', 5., None, False, False),
('invertedbottleneck', 3, 1, 64, 'relu', 3., None, False), ('invertedbottleneck', 3, 1, 64, 'relu', 3., None, False, False),
('invertedbottleneck', 3, 1, 64, 'relu', 2., None, False), ('invertedbottleneck', 3, 1, 64, 'relu', 2., None, False, False),
('invertedbottleneck', 3, 1, 64, 'relu', 3., None, False), ('invertedbottleneck', 3, 1, 64, 'relu', 3., None, False, True),
('invertedbottleneck', 5, 2, 128, 'relu', 6., None, False), ('invertedbottleneck', 5, 2, 128, 'relu', 6., None, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False), ('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False), ('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False), ('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False, False),
('invertedbottleneck', 3, 1, 160, 'relu', 6., None, False), ('invertedbottleneck', 3, 1, 160, 'relu', 6., None, False, False),
('invertedbottleneck', 3, 1, 160, 'relu', 4., None, False), ('invertedbottleneck', 3, 1, 160, 'relu', 4., None, False, True),
('invertedbottleneck', 3, 2, 192, 'relu', 6., None, False), ('invertedbottleneck', 3, 2, 192, 'relu', 6., None, False, False),
('invertedbottleneck', 5, 1, 192, 'relu', 4., None, False), ('invertedbottleneck', 5, 1, 192, 'relu', 4., None, False, False),
('invertedbottleneck', 5, 1, 192, 'relu', 4., None, False), ('invertedbottleneck', 5, 1, 192, 'relu', 4., None, False, False),
('invertedbottleneck', 5, 1, 192, 'relu', 4., None, False), ('invertedbottleneck', 5, 1, 192, 'relu', 4., None, False, True),
('convbn', 1, 1, 960, 'relu', None, True, False), ('convbn', 1, 1, 960, 'relu', None, True, False, False),
('gpooling', None, None, None, None, None, None, None), ('gpooling', None, None, None, None, None, None, None, False),
('convbn', 1, 1, 1280, 'relu', None, False, True), ('convbn', 1, 1, 1280, 'relu', None, False, True, False),
] ]
} }
...@@ -388,6 +423,7 @@ class BlockSpec(hyperparams.Config): ...@@ -388,6 +423,7 @@ class BlockSpec(hyperparams.Config):
se_ratio: Optional[float] = None se_ratio: Optional[float] = None
use_depthwise: bool = True use_depthwise: bool = True
use_residual: bool = True use_residual: bool = True
is_output: bool = True
def block_spec_decoder(specs: Dict[Any, Any], def block_spec_decoder(specs: Dict[Any, Any],
...@@ -552,9 +588,9 @@ class MobileNet(tf.keras.Model): ...@@ -552,9 +588,9 @@ class MobileNet(tf.keras.Model):
divisible_by=self._get_divisible_by(), divisible_by=self._get_divisible_by(),
finegrain_classification_mode=self._finegrain_classification_mode) finegrain_classification_mode=self._finegrain_classification_mode)
x, endpoints = self._mobilenet_base(inputs=inputs) x, endpoints, next_endpoint_level = self._mobilenet_base(inputs=inputs)
endpoints[max(endpoints.keys()) + 1] = x endpoints[str(next_endpoint_level)] = x
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints} self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
super(MobileNet, self).__init__( super(MobileNet, self).__init__(
...@@ -568,7 +604,7 @@ class MobileNet(tf.keras.Model): ...@@ -568,7 +604,7 @@ class MobileNet(tf.keras.Model):
def _mobilenet_base(self, def _mobilenet_base(self,
inputs: tf.Tensor inputs: tf.Tensor
) -> Tuple[tf.Tensor, Dict[int, tf.Tensor]]: ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor], int]:
"""Builds the base MobileNet architecture. """Builds the base MobileNet architecture.
Args: Args:
...@@ -594,7 +630,7 @@ class MobileNet(tf.keras.Model): ...@@ -594,7 +630,7 @@ class MobileNet(tf.keras.Model):
net = inputs net = inputs
endpoints = {} endpoints = {}
endpoint_level = 1 endpoint_level = 2
for i, block_def in enumerate(self._decoded_specs): for i, block_def in enumerate(self._decoded_specs):
block_name = 'block_group_{}_{}'.format(block_def.block_fn, i) block_name = 'block_group_{}_{}'.format(block_def.block_fn, i)
# A small catch for gpooling block with None strides # A small catch for gpooling block with None strides
...@@ -688,10 +724,13 @@ class MobileNet(tf.keras.Model): ...@@ -688,10 +724,13 @@ class MobileNet(tf.keras.Model):
raise ValueError('Unknown block type {} for layer {}'.format( raise ValueError('Unknown block type {} for layer {}'.format(
block_def.block_fn, i)) block_def.block_fn, i))
endpoints[endpoint_level] = net
endpoint_level += 1
net = tf.identity(net, name=block_name) net = tf.identity(net, name=block_name)
return net, endpoints
if block_def.is_output:
endpoints[str(endpoint_level)] = net
endpoint_level += 1
return net, endpoints, endpoint_level
def get_config(self): def get_config(self):
config_dict = { config_dict = {
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,8 @@ ...@@ -12,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
# Lint as: python3
"""Tests for MobileNet.""" """Tests for MobileNet."""
import itertools import itertools
...@@ -109,27 +109,27 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -109,27 +109,27 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
tf.keras.backend.set_image_data_format('channels_last') tf.keras.backend.set_image_data_format('channels_last')
mobilenet_layers = { mobilenet_layers = {
# The stride (relative to input) and number of filters # The number of filters of layers having outputs been collected
# of first few layers for filter_size_scale = 0.75 # for filter_size_scale = 1.0
'MobileNetV1': [(1, 24), (1, 48), (2, 96), (2, 96)], 'MobileNetV1': [128, 256, 512, 1024],
'MobileNetV2': [(1, 24), (1, 16), (2, 24), (2, 24)], 'MobileNetV2': [24, 32, 96, 320],
'MobileNetV3Small': [(1, 16), (2, 16), (3, 24), (3, 24)], 'MobileNetV3Small': [16, 24, 48, 96],
'MobileNetV3Large': [(1, 16), (1, 16), (2, 24), (2, 24)], 'MobileNetV3Large': [24, 40, 112, 160],
'MobileNetV3EdgeTPU': [(1, 24), (1, 16), (2, 24), (2, 24)], 'MobileNetV3EdgeTPU': [32, 48, 96, 192],
'MobileNetMultiMAX': [(1, 24), (2, 24), (3, 48), (3, 48)], 'MobileNetMultiMAX': [32, 64, 128, 160],
'MobileNetMultiAVG': [(1, 24), (2, 24), (2, 24), (3, 48)], 'MobileNetMultiAVG': [32, 64, 160, 192],
} }
network = mobilenet.MobileNet(model_id=model_id, network = mobilenet.MobileNet(model_id=model_id,
filter_size_scale=0.75) filter_size_scale=1.0)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1) inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
endpoints = network(inputs) endpoints = network(inputs)
for idx, (stride, num_filter) in enumerate(mobilenet_layers[model_id]): for idx, num_filter in enumerate(mobilenet_layers[model_id]):
self.assertAllEqual( self.assertAllEqual(
[1, input_size / 2 ** stride, input_size / 2 ** stride, num_filter], [1, input_size / 2 ** (idx+2), input_size / 2 ** (idx+2), num_filter],
endpoints[idx+1].shape.as_list()) endpoints[str(idx+2)].shape.as_list())
@parameterized.parameters( @parameterized.parameters(
itertools.product( itertools.product(
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Contains definitions of Residual Networks.""" """Contains definitions of Residual Networks."""
# Import libraries # Import libraries
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Contains definitions of 3D Residual Networks.""" """Contains definitions of 3D Residual Networks."""
from typing import List, Tuple from typing import List, Tuple
...@@ -163,7 +163,7 @@ class ResNet3D(tf.keras.Model): ...@@ -163,7 +163,7 @@ class ResNet3D(tf.keras.Model):
block_repeats=resnet_spec[2], block_repeats=resnet_spec[2],
use_self_gating=use_self_gating[i] if use_self_gating else False, use_self_gating=use_self_gating[i] if use_self_gating else False,
name='block_group_l{}'.format(i + 2)) name='block_group_l{}'.format(i + 2))
endpoints[i + 2] = x endpoints[str(i + 2)] = x
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints} self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
......
# Lint as: python3 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,8 @@ ...@@ -12,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
# Lint as: python3
"""Tests for resnet.""" """Tests for resnet."""
# Import libraries # Import libraries
...@@ -47,16 +47,16 @@ class ResNet3DTest(parameterized.TestCase, tf.test.TestCase): ...@@ -47,16 +47,16 @@ class ResNet3DTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual([ self.assertAllEqual([
1, 2, input_size / 2**2, input_size / 2**2, 64 * endpoint_filter_scale 1, 2, input_size / 2**2, input_size / 2**2, 64 * endpoint_filter_scale
], endpoints[2].shape.as_list()) ], endpoints['2'].shape.as_list())
self.assertAllEqual([ self.assertAllEqual([
1, 2, input_size / 2**3, input_size / 2**3, 128 * endpoint_filter_scale 1, 2, input_size / 2**3, input_size / 2**3, 128 * endpoint_filter_scale
], endpoints[3].shape.as_list()) ], endpoints['3'].shape.as_list())
self.assertAllEqual([ self.assertAllEqual([
1, 2, input_size / 2**4, input_size / 2**4, 256 * endpoint_filter_scale 1, 2, input_size / 2**4, input_size / 2**4, 256 * endpoint_filter_scale
], endpoints[4].shape.as_list()) ], endpoints['4'].shape.as_list())
self.assertAllEqual([ self.assertAllEqual([
1, 2, input_size / 2**5, input_size / 2**5, 512 * endpoint_filter_scale 1, 2, input_size / 2**5, input_size / 2**5, 512 * endpoint_filter_scale
], endpoints[5].shape.as_list()) ], endpoints['5'].shape.as_list())
def test_serialize_deserialize(self): def test_serialize_deserialize(self):
# Create a network object that sets all of its config options. # Create a network object that sets all of its config options.
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Contains definitions of Residual Networks with Deeplab modifications.""" """Contains definitions of Residual Networks with Deeplab modifications."""
import numpy as np import numpy as np
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment