Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8b641b13
Unverified
Commit
8b641b13
authored
Mar 26, 2022
by
Srihari Humbarwadi
Committed by
GitHub
Mar 26, 2022
Browse files
Merge branch 'tensorflow:master' into panoptic-deeplab
parents
7cffacfe
357fa547
Changes
503
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
40 additions
and
89 deletions
+40
-89
official/projects/movinet/train.py
official/projects/movinet/train.py
+0
-1
official/projects/movinet/train_test.py
official/projects/movinet/train_test.py
+0
-1
official/projects/pruning/configs/__init__.py
official/projects/pruning/configs/__init__.py
+0
-1
official/projects/pruning/configs/image_classification.py
official/projects/pruning/configs/image_classification.py
+0
-1
official/projects/pruning/configs/image_classification_test.py
...ial/projects/pruning/configs/image_classification_test.py
+1
-2
official/projects/pruning/tasks/__init__.py
official/projects/pruning/tasks/__init__.py
+0
-1
official/projects/pruning/tasks/image_classification.py
official/projects/pruning/tasks/image_classification.py
+0
-1
official/projects/pruning/tasks/image_classification_test.py
official/projects/pruning/tasks/image_classification_test.py
+2
-3
official/projects/qat/vision/configs/__init__.py
official/projects/qat/vision/configs/__init__.py
+0
-1
official/projects/qat/vision/configs/common.py
official/projects/qat/vision/configs/common.py
+0
-1
official/projects/qat/vision/configs/image_classification.py
official/projects/qat/vision/configs/image_classification.py
+0
-1
official/projects/qat/vision/configs/image_classification_test.py
.../projects/qat/vision/configs/image_classification_test.py
+1
-1
official/projects/qat/vision/configs/retinanet.py
official/projects/qat/vision/configs/retinanet.py
+1
-2
official/projects/qat/vision/configs/retinanet_test.py
official/projects/qat/vision/configs/retinanet_test.py
+1
-1
official/projects/qat/vision/configs/semantic_segmentation.py
...cial/projects/qat/vision/configs/semantic_segmentation.py
+0
-1
official/projects/qat/vision/configs/semantic_segmentation_test.py
...projects/qat/vision/configs/semantic_segmentation_test.py
+1
-1
official/projects/qat/vision/modeling/__init__.py
official/projects/qat/vision/modeling/__init__.py
+0
-1
official/projects/qat/vision/modeling/layers/__init__.py
official/projects/qat/vision/modeling/layers/__init__.py
+0
-1
official/projects/qat/vision/modeling/layers/nn_blocks.py
official/projects/qat/vision/modeling/layers/nn_blocks.py
+33
-66
official/projects/qat/vision/modeling/layers/nn_blocks_test.py
...ial/projects/qat/vision/modeling/layers/nn_blocks_test.py
+0
-1
No files found.
official/projects/movinet/train.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
r
"""Training driver.
r
"""Training driver.
To train:
To train:
...
...
official/projects/movinet/train_test.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Tests for train.py."""
"""Tests for train.py."""
import
json
import
json
...
...
official/projects/pruning/configs/__init__.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
"""Configs package definition."""
from
official.projects.pruning.configs
import
image_classification
from
official.projects.pruning.configs
import
image_classification
official/projects/pruning/configs/image_classification.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Image classification configuration definition."""
"""Image classification configuration definition."""
import
dataclasses
import
dataclasses
...
...
official/projects/pruning/configs/image_classification_test.py
View file @
8b641b13
...
@@ -12,16 +12,15 @@
...
@@ -12,16 +12,15 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Tests for image_classification."""
"""Tests for image_classification."""
# pylint: disable=unused-import
# pylint: disable=unused-import
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official
import
vision
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.projects.pruning.configs
import
image_classification
as
pruning_exp_cfg
from
official.projects.pruning.configs
import
image_classification
as
pruning_exp_cfg
from
official.vision
import
beta
from
official.vision.configs
import
image_classification
as
exp_cfg
from
official.vision.configs
import
image_classification
as
exp_cfg
...
...
official/projects/pruning/tasks/__init__.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Modeling package definition."""
"""Modeling package definition."""
from
official.projects.pruning.tasks
import
image_classification
from
official.projects.pruning.tasks
import
image_classification
official/projects/pruning/tasks/image_classification.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Image classification task definition."""
"""Image classification task definition."""
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
...
...
official/projects/pruning/tasks/image_classification_test.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Tests for image classification task."""
"""Tests for image classification task."""
# pylint: disable=unused-import
# pylint: disable=unused-import
...
@@ -22,13 +21,13 @@ from absl.testing import parameterized
...
@@ -22,13 +21,13 @@ from absl.testing import parameterized
import
numpy
as
np
import
numpy
as
np
import
orbit
import
orbit
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow_model_optimization
as
tfmot
import
tensorflow_model_optimization
as
tfmot
from
official
import
vision
from
official.core
import
actions
from
official.core
import
actions
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.projects.pruning.tasks
import
image_classification
as
img_cls_task
from
official.projects.pruning.tasks
import
image_classification
as
img_cls_task
from
official.vision
import
beta
class
ImageClassificationTaskTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
class
ImageClassificationTaskTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
...
...
official/projects/qat/vision/configs/__init__.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
"""Configs package definition."""
from
official.projects.qat.vision.configs
import
image_classification
from
official.projects.qat.vision.configs
import
image_classification
...
...
official/projects/qat/vision/configs/common.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Image classification configuration definition."""
"""Image classification configuration definition."""
import
dataclasses
import
dataclasses
...
...
official/projects/qat/vision/configs/image_classification.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Image classification configuration definition."""
"""Image classification configuration definition."""
import
dataclasses
import
dataclasses
...
...
official/projects/qat/vision/configs/image_classification_test.py
View file @
8b641b13
...
@@ -17,11 +17,11 @@
...
@@ -17,11 +17,11 @@
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official
import
vision
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.projects.qat.vision.configs
import
common
from
official.projects.qat.vision.configs
import
common
from
official.projects.qat.vision.configs
import
image_classification
as
qat_exp_cfg
from
official.projects.qat.vision.configs
import
image_classification
as
qat_exp_cfg
from
official.vision
import
beta
from
official.vision.configs
import
image_classification
as
exp_cfg
from
official.vision.configs
import
image_classification
as
exp_cfg
...
...
official/projects/qat/vision/configs/retinanet.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""RetinaNet configuration definition."""
"""RetinaNet configuration definition."""
import
dataclasses
import
dataclasses
from
typing
import
Optional
from
typing
import
Optional
...
@@ -21,7 +20,7 @@ from official.core import config_definitions as cfg
...
@@ -21,7 +20,7 @@ from official.core import config_definitions as cfg
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.projects.qat.vision.configs
import
common
from
official.projects.qat.vision.configs
import
common
from
official.vision.configs
import
retinanet
from
official.vision.configs
import
retinanet
from
official.vision.configs
.google
import
backbones
from
official.vision.configs
import
backbones
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/projects/qat/vision/configs/retinanet_test.py
View file @
8b641b13
...
@@ -17,11 +17,11 @@
...
@@ -17,11 +17,11 @@
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official
import
vision
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.projects.qat.vision.configs
import
common
from
official.projects.qat.vision.configs
import
common
from
official.projects.qat.vision.configs
import
retinanet
as
qat_exp_cfg
from
official.projects.qat.vision.configs
import
retinanet
as
qat_exp_cfg
from
official.vision
import
beta
from
official.vision.configs
import
retinanet
as
exp_cfg
from
official.vision.configs
import
retinanet
as
exp_cfg
...
...
official/projects/qat/vision/configs/semantic_segmentation.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""RetinaNet configuration definition."""
"""RetinaNet configuration definition."""
import
dataclasses
import
dataclasses
from
typing
import
Optional
from
typing
import
Optional
...
...
official/projects/qat/vision/configs/semantic_segmentation_test.py
View file @
8b641b13
...
@@ -17,11 +17,11 @@
...
@@ -17,11 +17,11 @@
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official
import
vision
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.projects.qat.vision.configs
import
common
from
official.projects.qat.vision.configs
import
common
from
official.projects.qat.vision.configs
import
semantic_segmentation
as
qat_exp_cfg
from
official.projects.qat.vision.configs
import
semantic_segmentation
as
qat_exp_cfg
from
official.vision
import
beta
from
official.vision.configs
import
semantic_segmentation
as
exp_cfg
from
official.vision.configs
import
semantic_segmentation
as
exp_cfg
...
...
official/projects/qat/vision/modeling/__init__.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Modeling package definition."""
"""Modeling package definition."""
from
official.projects.qat.vision.modeling
import
layers
from
official.projects.qat.vision.modeling
import
layers
official/projects/qat/vision/modeling/layers/__init__.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Layers package definition."""
"""Layers package definition."""
from
official.projects.qat.vision.modeling.layers.nn_blocks
import
BottleneckBlockQuantized
from
official.projects.qat.vision.modeling.layers.nn_blocks
import
BottleneckBlockQuantized
...
...
official/projects/qat/vision/modeling/layers/nn_blocks.py
View file @
8b641b13
...
@@ -24,42 +24,10 @@ import tensorflow_model_optimization as tfmot
...
@@ -24,42 +24,10 @@ import tensorflow_model_optimization as tfmot
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.projects.qat.vision.modeling.layers
import
nn_layers
as
qat_nn_layers
from
official.projects.qat.vision.modeling.layers
import
nn_layers
as
qat_nn_layers
from
official.projects.qat.vision.quantization
import
configs
from
official.projects.qat.vision.quantization
import
configs
from
official.projects.qat.vision.quantization
import
helper
from
official.vision.modeling.layers
import
nn_layers
from
official.vision.modeling.layers
import
nn_layers
class
NoOpActivation
:
"""No-op activation which simply returns the incoming tensor.
This activation is required to distinguish between `keras.activations.linear`
which does the same thing. The main difference is that NoOpActivation should
not have any quantize operation applied to it.
"""
def
__call__
(
self
,
x
:
tf
.
Tensor
)
->
tf
.
Tensor
:
return
x
def
get_config
(
self
)
->
Dict
[
str
,
Any
]:
"""Get a config of this object."""
return
{}
def
__eq__
(
self
,
other
:
Any
)
->
bool
:
if
not
other
or
not
isinstance
(
other
,
NoOpActivation
):
return
False
return
True
def
__ne__
(
self
,
other
:
Any
)
->
bool
:
return
not
self
.
__eq__
(
other
)
def
_quantize_wrapped_layer
(
cls
,
quantize_config
):
def
constructor
(
*
arg
,
**
kwargs
):
return
tfmot
.
quantization
.
keras
.
QuantizeWrapperV2
(
cls
(
*
arg
,
**
kwargs
),
quantize_config
)
return
constructor
# This class is copied from modeling.layers.nn_blocks.BottleneckBlock and apply
# This class is copied from modeling.layers.nn_blocks.BottleneckBlock and apply
# QAT.
# QAT.
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
...
@@ -131,17 +99,16 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
...
@@ -131,17 +99,16 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
self
.
_bias_regularizer
=
bias_regularizer
if
use_sync_bn
:
if
use_sync_bn
:
self
.
_norm
=
_
quantize_wrapped_layer
(
self
.
_norm
=
helper
.
quantize_wrapped_layer
(
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
,
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
,
configs
.
NoOpQuantizeConfig
())
configs
.
NoOpQuantizeConfig
())
self
.
_norm_with_quantize
=
_
quantize_wrapped_layer
(
self
.
_norm_with_quantize
=
helper
.
quantize_wrapped_layer
(
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
,
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
,
configs
.
Default8BitOutputQuantizeConfig
())
configs
.
Default8BitOutputQuantizeConfig
())
else
:
else
:
self
.
_norm
=
_quantize_wrapped_layer
(
self
.
_norm
=
helper
.
quantize_wrapped_layer
(
tf
.
keras
.
layers
.
BatchNormalization
,
tf
.
keras
.
layers
.
BatchNormalization
,
configs
.
NoOpQuantizeConfig
())
configs
.
NoOpQuantizeConfig
())
self
.
_norm_with_quantize
=
helper
.
quantize_wrapped_layer
(
self
.
_norm_with_quantize
=
_quantize_wrapped_layer
(
tf
.
keras
.
layers
.
BatchNormalization
,
tf
.
keras
.
layers
.
BatchNormalization
,
configs
.
Default8BitOutputQuantizeConfig
())
configs
.
Default8BitOutputQuantizeConfig
())
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
...
@@ -152,10 +119,10 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
...
@@ -152,10 +119,10 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
def
build
(
self
,
input_shape
:
Optional
[
Union
[
Sequence
[
int
],
tf
.
Tensor
]]):
def
build
(
self
,
input_shape
:
Optional
[
Union
[
Sequence
[
int
],
tf
.
Tensor
]]):
"""Build variables and child layers to prepare for calling."""
"""Build variables and child layers to prepare for calling."""
conv2d_quantized
=
_
quantize_wrapped_layer
(
conv2d_quantized
=
helper
.
quantize_wrapped_layer
(
tf
.
keras
.
layers
.
Conv2D
,
tf
.
keras
.
layers
.
Conv2D
,
configs
.
Default8BitConvQuantizeConfig
(
configs
.
Default8BitConvQuantizeConfig
(
[
'kernel'
],
[
'activation'
],
[
'kernel'
],
[
'activation'
],
False
))
False
))
if
self
.
_use_projection
:
if
self
.
_use_projection
:
if
self
.
_resnetd_shortcut
:
if
self
.
_resnetd_shortcut
:
self
.
_shortcut0
=
tf
.
keras
.
layers
.
AveragePooling2D
(
self
.
_shortcut0
=
tf
.
keras
.
layers
.
AveragePooling2D
(
...
@@ -168,7 +135,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
...
@@ -168,7 +135,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
NoOpActivation
())
activation
=
helper
.
NoOpActivation
())
else
:
else
:
self
.
_shortcut
=
conv2d_quantized
(
self
.
_shortcut
=
conv2d_quantized
(
filters
=
self
.
_filters
*
4
,
filters
=
self
.
_filters
*
4
,
...
@@ -178,7 +145,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
...
@@ -178,7 +145,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
NoOpActivation
())
activation
=
helper
.
NoOpActivation
())
self
.
_norm0
=
self
.
_norm_with_quantize
(
self
.
_norm0
=
self
.
_norm_with_quantize
(
axis
=
self
.
_bn_axis
,
axis
=
self
.
_bn_axis
,
...
@@ -194,7 +161,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
...
@@ -194,7 +161,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
NoOpActivation
())
activation
=
helper
.
NoOpActivation
())
self
.
_norm1
=
self
.
_norm
(
self
.
_norm1
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
momentum
=
self
.
_norm_momentum
,
...
@@ -214,7 +181,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
...
@@ -214,7 +181,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
NoOpActivation
())
activation
=
helper
.
NoOpActivation
())
self
.
_norm2
=
self
.
_norm
(
self
.
_norm2
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
momentum
=
self
.
_norm_momentum
,
...
@@ -232,7 +199,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
...
@@ -232,7 +199,7 @@ class BottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
NoOpActivation
())
activation
=
helper
.
NoOpActivation
())
self
.
_norm3
=
self
.
_norm_with_quantize
(
self
.
_norm3
=
self
.
_norm_with_quantize
(
axis
=
self
.
_bn_axis
,
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
momentum
=
self
.
_norm_momentum
,
...
@@ -392,10 +359,10 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
...
@@ -392,10 +359,10 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
norm_layer
=
(
norm_layer
=
(
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
if
use_sync_bn
else
tf
.
keras
.
layers
.
BatchNormalization
)
if
use_sync_bn
else
tf
.
keras
.
layers
.
BatchNormalization
)
self
.
_norm_with_quantize
=
_
quantize_wrapped_layer
(
self
.
_norm_with_quantize
=
helper
.
quantize_wrapped_layer
(
norm_layer
,
configs
.
Default8BitOutputQuantizeConfig
())
norm_layer
,
configs
.
Default8BitOutputQuantizeConfig
())
self
.
_norm
=
_
quantize_wrapped_layer
(
norm_layer
,
self
.
_norm
=
helper
.
quantize_wrapped_layer
(
norm_layer
,
configs
.
NoOpQuantizeConfig
())
configs
.
NoOpQuantizeConfig
())
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_bn_axis
=
-
1
self
.
_bn_axis
=
-
1
...
@@ -432,10 +399,10 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
...
@@ -432,10 +399,10 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
if
self
.
_use_explicit_padding
and
self
.
_kernel_size
>
1
:
if
self
.
_use_explicit_padding
and
self
.
_kernel_size
>
1
:
padding_size
=
nn_layers
.
get_padding_for_kernel_size
(
self
.
_kernel_size
)
padding_size
=
nn_layers
.
get_padding_for_kernel_size
(
self
.
_kernel_size
)
self
.
_pad
=
tf
.
keras
.
layers
.
ZeroPadding2D
(
padding_size
)
self
.
_pad
=
tf
.
keras
.
layers
.
ZeroPadding2D
(
padding_size
)
conv2d_quantized
=
_
quantize_wrapped_layer
(
conv2d_quantized
=
helper
.
quantize_wrapped_layer
(
tf
.
keras
.
layers
.
Conv2D
,
tf
.
keras
.
layers
.
Conv2D
,
configs
.
Default8BitConvQuantizeConfig
(
configs
.
Default8BitConvQuantizeConfig
(
[
'kernel'
],
[
'activation'
],
[
'kernel'
],
[
'activation'
],
not
self
.
_use_normalization
))
not
self
.
_use_normalization
))
self
.
_conv0
=
conv2d_quantized
(
self
.
_conv0
=
conv2d_quantized
(
filters
=
self
.
_filters
,
filters
=
self
.
_filters
,
kernel_size
=
self
.
_kernel_size
,
kernel_size
=
self
.
_kernel_size
,
...
@@ -445,7 +412,7 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
...
@@ -445,7 +412,7 @@ class Conv2DBNBlockQuantized(tf.keras.layers.Layer):
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
NoOpActivation
())
activation
=
helper
.
NoOpActivation
())
if
self
.
_use_normalization
:
if
self
.
_use_normalization
:
self
.
_norm0
=
self
.
_norm_by_activation
(
self
.
_activation
)(
self
.
_norm0
=
self
.
_norm_by_activation
(
self
.
_activation
)(
axis
=
self
.
_bn_axis
,
axis
=
self
.
_bn_axis
,
...
@@ -579,10 +546,10 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
...
@@ -579,10 +546,10 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
norm_layer
=
(
norm_layer
=
(
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
if
use_sync_bn
else
tf
.
keras
.
layers
.
BatchNormalization
)
if
use_sync_bn
else
tf
.
keras
.
layers
.
BatchNormalization
)
self
.
_norm_with_quantize
=
_
quantize_wrapped_layer
(
self
.
_norm_with_quantize
=
helper
.
quantize_wrapped_layer
(
norm_layer
,
configs
.
Default8BitOutputQuantizeConfig
())
norm_layer
,
configs
.
Default8BitOutputQuantizeConfig
())
self
.
_norm
=
_
quantize_wrapped_layer
(
norm_layer
,
self
.
_norm
=
helper
.
quantize_wrapped_layer
(
norm_layer
,
configs
.
NoOpQuantizeConfig
())
configs
.
NoOpQuantizeConfig
())
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_bn_axis
=
-
1
self
.
_bn_axis
=
-
1
...
@@ -602,14 +569,14 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
...
@@ -602,14 +569,14 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
def
build
(
self
,
input_shape
:
Optional
[
Union
[
Sequence
[
int
],
tf
.
Tensor
]]):
def
build
(
self
,
input_shape
:
Optional
[
Union
[
Sequence
[
int
],
tf
.
Tensor
]]):
"""Build variables and child layers to prepare for calling."""
"""Build variables and child layers to prepare for calling."""
conv2d_quantized
=
_
quantize_wrapped_layer
(
conv2d_quantized
=
helper
.
quantize_wrapped_layer
(
tf
.
keras
.
layers
.
Conv2D
,
tf
.
keras
.
layers
.
Conv2D
,
configs
.
Default8BitConvQuantizeConfig
(
configs
.
Default8BitConvQuantizeConfig
(
[
'kernel'
],
[
'activation'
],
[
'kernel'
],
[
'activation'
],
False
))
False
))
depthwise_conv2d_quantized
=
_
quantize_wrapped_layer
(
depthwise_conv2d_quantized
=
helper
.
quantize_wrapped_layer
(
tf
.
keras
.
layers
.
DepthwiseConv2D
,
tf
.
keras
.
layers
.
DepthwiseConv2D
,
configs
.
Default8BitConvQuantizeConfig
(
configs
.
Default8BitConvQuantizeConfig
(
[
'depthwise_kernel'
],
[
'depthwise_kernel'
],
[
'activation'
],
False
))
[
'activation'
],
False
))
expand_filters
=
self
.
_in_filters
expand_filters
=
self
.
_in_filters
if
self
.
_expand_ratio
>
1
:
if
self
.
_expand_ratio
>
1
:
# First 1x1 conv for channel expansion.
# First 1x1 conv for channel expansion.
...
@@ -628,7 +595,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
...
@@ -628,7 +595,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
NoOpActivation
())
activation
=
helper
.
NoOpActivation
())
self
.
_norm0
=
self
.
_norm_by_activation
(
self
.
_activation
)(
self
.
_norm0
=
self
.
_norm_by_activation
(
self
.
_activation
)(
axis
=
self
.
_bn_axis
,
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
momentum
=
self
.
_norm_momentum
,
...
@@ -649,7 +616,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
...
@@ -649,7 +616,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
depthwise_initializer
=
self
.
_kernel_initializer
,
depthwise_initializer
=
self
.
_kernel_initializer
,
depthwise_regularizer
=
self
.
_depthsize_regularizer
,
depthwise_regularizer
=
self
.
_depthsize_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
NoOpActivation
())
activation
=
helper
.
NoOpActivation
())
self
.
_norm1
=
self
.
_norm_by_activation
(
self
.
_depthwise_activation
)(
self
.
_norm1
=
self
.
_norm_by_activation
(
self
.
_depthwise_activation
)(
axis
=
self
.
_bn_axis
,
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
momentum
=
self
.
_norm_momentum
,
...
@@ -690,7 +657,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
...
@@ -690,7 +657,7 @@ class InvertedBottleneckBlockQuantized(tf.keras.layers.Layer):
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activation
=
NoOpActivation
())
activation
=
helper
.
NoOpActivation
())
self
.
_norm2
=
self
.
_norm_with_quantize
(
self
.
_norm2
=
self
.
_norm_with_quantize
(
axis
=
self
.
_bn_axis
,
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
momentum
=
self
.
_norm_momentum
,
...
...
official/projects/qat/vision/modeling/layers/nn_blocks_test.py
View file @
8b641b13
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Tests for nn_blocks."""
"""Tests for nn_blocks."""
from
typing
import
Any
,
Iterable
,
Tuple
from
typing
import
Any
,
Iterable
,
Tuple
...
...
Prev
1
2
3
4
5
6
7
8
9
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment