Unverified Commit 8b641b13 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab

parents 7cffacfe 357fa547
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
r"""Training driver. r"""Training driver.
To train: To train:
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for train.py.""" """Tests for train.py."""
import json import json
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Configs package definition.""" """Configs package definition."""
from official.projects.pruning.configs import image_classification from official.projects.pruning.configs import image_classification
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Image classification configuration definition.""" """Image classification configuration definition."""
import dataclasses import dataclasses
......
...@@ -12,16 +12,15 @@ ...@@ -12,16 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for image_classification.""" """Tests for image_classification."""
# pylint: disable=unused-import # pylint: disable=unused-import
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official import vision
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.projects.pruning.configs import image_classification as pruning_exp_cfg from official.projects.pruning.configs import image_classification as pruning_exp_cfg
from official.vision import beta
from official.vision.configs import image_classification as exp_cfg from official.vision.configs import image_classification as exp_cfg
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Modeling package definition.""" """Modeling package definition."""
from official.projects.pruning.tasks import image_classification from official.projects.pruning.tasks import image_classification
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Image classification task definition.""" """Image classification task definition."""
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for image classification task.""" """Tests for image classification task."""
# pylint: disable=unused-import # pylint: disable=unused-import
...@@ -22,13 +21,13 @@ from absl.testing import parameterized ...@@ -22,13 +21,13 @@ from absl.testing import parameterized
import numpy as np import numpy as np
import orbit import orbit
import tensorflow as tf import tensorflow as tf
import tensorflow_model_optimization as tfmot
import tensorflow_model_optimization as tfmot
from official import vision
from official.core import actions from official.core import actions
from official.core import exp_factory from official.core import exp_factory
from official.modeling import optimization from official.modeling import optimization
from official.projects.pruning.tasks import image_classification as img_cls_task from official.projects.pruning.tasks import image_classification as img_cls_task
from official.vision import beta
class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase): class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase):
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Configs package definition.""" """Configs package definition."""
from official.projects.qat.vision.configs import image_classification from official.projects.qat.vision.configs import image_classification
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Image classification configuration definition.""" """Image classification configuration definition."""
import dataclasses import dataclasses
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Image classification configuration definition.""" """Image classification configuration definition."""
import dataclasses import dataclasses
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official import vision
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.projects.qat.vision.configs import common from official.projects.qat.vision.configs import common
from official.projects.qat.vision.configs import image_classification as qat_exp_cfg from official.projects.qat.vision.configs import image_classification as qat_exp_cfg
from official.vision import beta
from official.vision.configs import image_classification as exp_cfg from official.vision.configs import image_classification as exp_cfg
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""RetinaNet configuration definition.""" """RetinaNet configuration definition."""
import dataclasses import dataclasses
from typing import Optional from typing import Optional
...@@ -21,7 +20,7 @@ from official.core import config_definitions as cfg ...@@ -21,7 +20,7 @@ from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.projects.qat.vision.configs import common from official.projects.qat.vision.configs import common
from official.vision.configs import retinanet from official.vision.configs import retinanet
from official.vision.configs.google import backbones from official.vision.configs import backbones
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official import vision
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.projects.qat.vision.configs import common from official.projects.qat.vision.configs import common
from official.projects.qat.vision.configs import retinanet as qat_exp_cfg from official.projects.qat.vision.configs import retinanet as qat_exp_cfg
from official.vision import beta
from official.vision.configs import retinanet as exp_cfg from official.vision.configs import retinanet as exp_cfg
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""RetinaNet configuration definition.""" """RetinaNet configuration definition."""
import dataclasses import dataclasses
from typing import Optional from typing import Optional
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official import vision
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.projects.qat.vision.configs import common from official.projects.qat.vision.configs import common
from official.projects.qat.vision.configs import semantic_segmentation as qat_exp_cfg from official.projects.qat.vision.configs import semantic_segmentation as qat_exp_cfg
from official.vision import beta
from official.vision.configs import semantic_segmentation as exp_cfg from official.vision.configs import semantic_segmentation as exp_cfg
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Modeling package definition.""" """Modeling package definition."""
from official.projects.qat.vision.modeling import layers from official.projects.qat.vision.modeling import layers
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Layers package definition.""" """Layers package definition."""
from official.projects.qat.vision.modeling.layers.nn_blocks import BottleneckBlockQuantized from official.projects.qat.vision.modeling.layers.nn_blocks import BottleneckBlockQuantized
......
...@@ -24,42 +24,10 @@ import tensorflow_model_optimization as tfmot ...@@ -24,42 +24,10 @@ import tensorflow_model_optimization as tfmot
from official.modeling import tf_utils from official.modeling import tf_utils
from official.projects.qat.vision.modeling.layers import nn_layers as qat_nn_layers from official.projects.qat.vision.modeling.layers import nn_layers as qat_nn_layers
from official.projects.qat.vision.quantization import configs from official.projects.qat.vision.quantization import configs
from official.projects.qat.vision.quantization import helper
from official.vision.modeling.layers import nn_layers from official.vision.modeling.layers import nn_layers
class NoOpActivation:
"""No-op activation which simply returns the incoming tensor.
This activation is required to distinguish between `keras.activations.linear`
which does the same thing. The main difference is that NoOpActivation should
not have any quantize operation applied to it.
"""
def __call__(self, x: tf.Tensor) -> tf.Tensor:
return x
def get_config(self) -> Dict[str, Any]:
"""Get a config of this object."""
return {}
def __eq__(self, other: Any) -> bool:
if not other or not isinstance(other, NoOpActivation):
return False
return True
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
def _quantize_wrapped_layer(cls, quantize_config):
def constructor(*arg, **kwargs):
return tfmot.quantization.keras.QuantizeWrapperV2(
cls(*arg, **kwargs),
quantize_config)
return constructor
# This class is copied from modeling.layers.nn_blocks.BottleneckBlock and apply # This class is copied from modeling.layers.nn_blocks.BottleneckBlock and apply
# QAT. # QAT.
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
...@@ -131,17 +99,16 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer): ...@@ -131,17 +99,16 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
if use_sync_bn: if use_sync_bn:
self._norm = _quantize_wrapped_layer( self._norm = helper.quantize_wrapped_layer(
tf.keras.layers.experimental.SyncBatchNormalization, tf.keras.layers.experimental.SyncBatchNormalization,
configs.NoOpQuantizeConfig()) configs.NoOpQuantizeConfig())
self._norm_with_quantize = _quantize_wrapped_layer( self._norm_with_quantize = helper.quantize_wrapped_layer(
tf.keras.layers.experimental.SyncBatchNormalization, tf.keras.layers.experimental.SyncBatchNormalization,
configs.Default8BitOutputQuantizeConfig()) configs.Default8BitOutputQuantizeConfig())
else: else:
self._norm = _quantize_wrapped_layer( self._norm = helper.quantize_wrapped_layer(
tf.keras.layers.BatchNormalization, tf.keras.layers.BatchNormalization, configs.NoOpQuantizeConfig())
configs.NoOpQuantizeConfig()) self._norm_with_quantize = helper.quantize_wrapped_layer(
self._norm_with_quantize = _quantize_wrapped_layer(
tf.keras.layers.BatchNormalization, tf.keras.layers.BatchNormalization,
configs.Default8BitOutputQuantizeConfig()) configs.Default8BitOutputQuantizeConfig())
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
...@@ -152,10 +119,10 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer): ...@@ -152,10 +119,10 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]): def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
"""Build variables and child layers to prepare for calling.""" """Build variables and child layers to prepare for calling."""
conv2d_quantized = _quantize_wrapped_layer( conv2d_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.Conv2D, tf.keras.layers.Conv2D,
configs.Default8BitConvQuantizeConfig( configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
['kernel'], ['activation'], False)) False))
if self._use_projection: if self._use_projection:
if self._resnetd_shortcut: if self._resnetd_shortcut:
self._shortcut0 = tf.keras.layers.AveragePooling2D( self._shortcut0 = tf.keras.layers.AveragePooling2D(
...@@ -168,7 +135,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer): ...@@ -168,7 +135,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activation=NoOpActivation()) activation=helper.NoOpActivation())
else: else:
self._shortcut = conv2d_quantized( self._shortcut = conv2d_quantized(
filters=self._filters * 4, filters=self._filters * 4,
...@@ -178,7 +145,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer): ...@@ -178,7 +145,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activation=NoOpActivation()) activation=helper.NoOpActivation())
self._norm0 = self._norm_with_quantize( self._norm0 = self._norm_with_quantize(
axis=self._bn_axis, axis=self._bn_axis,
...@@ -194,7 +161,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer): ...@@ -194,7 +161,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activation=NoOpActivation()) activation=helper.NoOpActivation())
self._norm1 = self._norm( self._norm1 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
...@@ -214,7 +181,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer): ...@@ -214,7 +181,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activation=NoOpActivation()) activation=helper.NoOpActivation())
self._norm2 = self._norm( self._norm2 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
...@@ -232,7 +199,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer): ...@@ -232,7 +199,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activation=NoOpActivation()) activation=helper.NoOpActivation())
self._norm3 = self._norm_with_quantize( self._norm3 = self._norm_with_quantize(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
...@@ -392,10 +359,10 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer): ...@@ -392,10 +359,10 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
norm_layer = ( norm_layer = (
tf.keras.layers.experimental.SyncBatchNormalization tf.keras.layers.experimental.SyncBatchNormalization
if use_sync_bn else tf.keras.layers.BatchNormalization) if use_sync_bn else tf.keras.layers.BatchNormalization)
self._norm_with_quantize = _quantize_wrapped_layer( self._norm_with_quantize = helper.quantize_wrapped_layer(
norm_layer, configs.Default8BitOutputQuantizeConfig()) norm_layer, configs.Default8BitOutputQuantizeConfig())
self._norm = _quantize_wrapped_layer(norm_layer, self._norm = helper.quantize_wrapped_layer(norm_layer,
configs.NoOpQuantizeConfig()) configs.NoOpQuantizeConfig())
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
...@@ -432,10 +399,10 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer): ...@@ -432,10 +399,10 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
if self._use_explicit_padding and self._kernel_size > 1: if self._use_explicit_padding and self._kernel_size > 1:
padding_size = nn_layers.get_padding_for_kernel_size(self._kernel_size) padding_size = nn_layers.get_padding_for_kernel_size(self._kernel_size)
self._pad = tf.keras.layers.ZeroPadding2D(padding_size) self._pad = tf.keras.layers.ZeroPadding2D(padding_size)
conv2d_quantized = _quantize_wrapped_layer( conv2d_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.Conv2D, tf.keras.layers.Conv2D,
configs.Default8BitConvQuantizeConfig( configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
['kernel'], ['activation'], not self._use_normalization)) not self._use_normalization))
self._conv0 = conv2d_quantized( self._conv0 = conv2d_quantized(
filters=self._filters, filters=self._filters,
kernel_size=self._kernel_size, kernel_size=self._kernel_size,
...@@ -445,7 +412,7 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer): ...@@ -445,7 +412,7 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activation=NoOpActivation()) activation=helper.NoOpActivation())
if self._use_normalization: if self._use_normalization:
self._norm0 = self._norm_by_activation(self._activation)( self._norm0 = self._norm_by_activation(self._activation)(
axis=self._bn_axis, axis=self._bn_axis,
...@@ -579,10 +546,10 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer): ...@@ -579,10 +546,10 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
norm_layer = ( norm_layer = (
tf.keras.layers.experimental.SyncBatchNormalization tf.keras.layers.experimental.SyncBatchNormalization
if use_sync_bn else tf.keras.layers.BatchNormalization) if use_sync_bn else tf.keras.layers.BatchNormalization)
self._norm_with_quantize = _quantize_wrapped_layer( self._norm_with_quantize = helper.quantize_wrapped_layer(
norm_layer, configs.Default8BitOutputQuantizeConfig()) norm_layer, configs.Default8BitOutputQuantizeConfig())
self._norm = _quantize_wrapped_layer(norm_layer, self._norm = helper.quantize_wrapped_layer(norm_layer,
configs.NoOpQuantizeConfig()) configs.NoOpQuantizeConfig())
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
...@@ -602,14 +569,14 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer): ...@@ -602,14 +569,14 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]): def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
"""Build variables and child layers to prepare for calling.""" """Build variables and child layers to prepare for calling."""
conv2d_quantized = _quantize_wrapped_layer( conv2d_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.Conv2D, tf.keras.layers.Conv2D,
configs.Default8BitConvQuantizeConfig( configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
['kernel'], ['activation'], False)) False))
depthwise_conv2d_quantized = _quantize_wrapped_layer( depthwise_conv2d_quantized = helper.quantize_wrapped_layer(
tf.keras.layers.DepthwiseConv2D, tf.keras.layers.DepthwiseConv2D,
configs.Default8BitConvQuantizeConfig( configs.Default8BitConvQuantizeConfig(['depthwise_kernel'],
['depthwise_kernel'], ['activation'], False)) ['activation'], False))
expand_filters = self._in_filters expand_filters = self._in_filters
if self._expand_ratio > 1: if self._expand_ratio > 1:
# First 1x1 conv for channel expansion. # First 1x1 conv for channel expansion.
...@@ -628,7 +595,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer): ...@@ -628,7 +595,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activation=NoOpActivation()) activation=helper.NoOpActivation())
self._norm0 = self._norm_by_activation(self._activation)( self._norm0 = self._norm_by_activation(self._activation)(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
...@@ -649,7 +616,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer): ...@@ -649,7 +616,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
depthwise_initializer=self._kernel_initializer, depthwise_initializer=self._kernel_initializer,
depthwise_regularizer=self._depthsize_regularizer, depthwise_regularizer=self._depthsize_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activation=NoOpActivation()) activation=helper.NoOpActivation())
self._norm1 = self._norm_by_activation(self._depthwise_activation)( self._norm1 = self._norm_by_activation(self._depthwise_activation)(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
...@@ -690,7 +657,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer): ...@@ -690,7 +657,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activation=NoOpActivation()) activation=helper.NoOpActivation())
self._norm2 = self._norm_with_quantize( self._norm2 = self._norm_with_quantize(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for nn_blocks.""" """Tests for nn_blocks."""
from typing import Any, Iterable, Tuple from typing import Any, Iterable, Tuple
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment