Unverified Commit 0225b135 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab-modeling

parents 7479dbb8 4c571a3c
...@@ -33,6 +33,8 @@ LayerPattern = tfmot.quantization.keras.graph_transformations.transforms.LayerPa ...@@ -33,6 +33,8 @@ LayerPattern = tfmot.quantization.keras.graph_transformations.transforms.LayerPa
_QUANTIZATION_WEIGHT_NAMES = [ _QUANTIZATION_WEIGHT_NAMES = [
'output_max', 'output_min', 'optimizer_step', 'output_max', 'output_min', 'optimizer_step',
'kernel_min', 'kernel_max', 'kernel_min', 'kernel_max',
'add_three_min', 'add_three_max',
'divide_six_min', 'divide_six_max',
'depthwise_kernel_min', 'depthwise_kernel_max', 'depthwise_kernel_min', 'depthwise_kernel_max',
'reduce_mean_quantizer_vars_min', 'reduce_mean_quantizer_vars_max'] 'reduce_mean_quantizer_vars_min', 'reduce_mean_quantizer_vars_max']
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for semantic segmentation task."""
# pylint: disable=unused-import
from absl.testing import parameterized
import orbit
import tensorflow as tf
from official.core import exp_factory
from official.modeling import optimization
from official.projects.qat.vision.tasks import semantic_segmentation
from official.vision import beta
from official.vision.beta.configs import semantic_segmentation as exp_cfg
class SemanticSegmentationTaskTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
('mnv2_deeplabv3_pascal_qat', True),
('mnv2_deeplabv3_pascal_qat', False),
)
def test_semantic_segmentation_task(self, test_config, is_training):
"""Semantic segmentation task test for training and val using toy configs."""
config = exp_factory.get_exp_config(test_config)
# modify config to suit local testing
config.task.model.input_size = [512, 512, 3]
config.trainer.steps_per_loop = 1
config.task.train_data.global_batch_size = 1
config.task.validation_data.global_batch_size = 1
config.task.train_data.shuffle_buffer_size = 2
config.task.validation_data.shuffle_buffer_size = 2
config.train_steps = 1
config.task.model.decoder.aspp.output_tensor = True
task = semantic_segmentation.SemanticSegmentationTask(config.task)
model = task.build_model()
metrics = task.build_metrics(training=is_training)
strategy = tf.distribute.get_strategy()
data_config = config.task.train_data if is_training else config.task.validation_data
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
data_config)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
if is_training:
task.train_step(next(iterator), model, optimizer, metrics=metrics)
else:
task.validation_step(next(iterator), model, metrics=metrics)
if __name__ == '__main__':
tf.test.main()
...@@ -18,8 +18,8 @@ import dataclasses ...@@ -18,8 +18,8 @@ import dataclasses
from typing import Text from typing import Text
from official.modeling import hyperparams from official.modeling import hyperparams
from official.vision.beta.configs import backbones_3d from official.vision.configs import backbones_3d
from official.vision.beta.configs import video_classification from official.vision.configs import video_classification
@dataclasses.dataclass @dataclasses.dataclass
...@@ -97,4 +97,3 @@ class S3DModel(video_classification.VideoClassificationModel): ...@@ -97,4 +97,3 @@ class S3DModel(video_classification.VideoClassificationModel):
""" """
model_type: str = 's3d' model_type: str = 's3d'
backbone: Backbone3D = Backbone3D() backbone: Backbone3D = Backbone3D()
...@@ -19,7 +19,7 @@ from typing import Callable, Dict, Optional, Sequence, Set, Text, Tuple, Type, U ...@@ -19,7 +19,7 @@ from typing import Callable, Dict, Optional, Sequence, Set, Text, Tuple, Type, U
import tensorflow as tf import tensorflow as tf
from official.projects.s3d.modeling import net_utils from official.projects.s3d.modeling import net_utils
from official.vision.beta.modeling.layers import nn_blocks_3d from official.vision.modeling.layers import nn_blocks_3d
INCEPTION_V1_CONV_ENDPOINTS = [ INCEPTION_V1_CONV_ENDPOINTS = [
'Conv2d_1a_7x7', 'Conv2d_2c_3x3', 'Mixed_3b', 'Mixed_3c', 'Mixed_4b', 'Conv2d_1a_7x7', 'Conv2d_2c_3x3', 'Mixed_3b', 'Mixed_3c', 'Mixed_4b',
......
...@@ -26,8 +26,8 @@ from official.modeling import hyperparams ...@@ -26,8 +26,8 @@ from official.modeling import hyperparams
from official.projects.s3d.configs import s3d as cfg from official.projects.s3d.configs import s3d as cfg
from official.projects.s3d.modeling import inception_utils from official.projects.s3d.modeling import inception_utils
from official.projects.s3d.modeling import net_utils from official.projects.s3d.modeling import net_utils
from official.vision.beta.modeling import factory_3d as model_factory from official.vision.modeling import factory_3d as model_factory
from official.vision.beta.modeling.backbones import factory as backbone_factory from official.vision.modeling.backbones import factory as backbone_factory
initializers = tf.keras.initializers initializers = tf.keras.initializers
regularizers = tf.keras.regularizers regularizers = tf.keras.regularizers
......
...@@ -17,16 +17,14 @@ ...@@ -17,16 +17,14 @@
from absl import app from absl import app
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags from official.common import flags as tfm_flags
# pylint: disable=unused-import # pylint: disable=unused-import
from official.projects.s3d.configs.google import s3d as s3d_config from official.projects.s3d.configs.google import s3d as s3d_config
from official.projects.s3d.modeling import s3d from official.projects.s3d.modeling import s3d
from official.projects.s3d.tasks.google import automl_video_classification from official.projects.s3d.tasks.google import automl_video_classification
from official.vision import registry_imports
# pylint: enable=unused-import # pylint: enable=unused-import
from official.vision.beta import train from official.vision import train
if __name__ == '__main__': if __name__ == '__main__':
tfm_flags.define_flags() tfm_flags.define_flags()
......
...@@ -15,4 +15,4 @@ ...@@ -15,4 +15,4 @@
# Lint as: python3 # Lint as: python3
"""Configs package definition.""" """Configs package definition."""
from official.vision.projects.video_ssl.configs import video_ssl from official.projects.video_ssl.configs import video_ssl
...@@ -20,8 +20,8 @@ import dataclasses ...@@ -20,8 +20,8 @@ import dataclasses
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.vision.beta.configs import common from official.vision.configs import common
from official.vision.beta.configs import video_classification from official.vision.configs import video_classification
Losses = video_classification.Losses Losses = video_classification.Losses
......
...@@ -18,10 +18,10 @@ ...@@ -18,10 +18,10 @@
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official import vision
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.vision import beta from official.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.projects.video_ssl.configs import video_ssl as exp_cfg
class VideoClassificationConfigTest(tf.test.TestCase, parameterized.TestCase): class VideoClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
......
...@@ -19,11 +19,10 @@ from typing import Dict, Optional, Tuple ...@@ -19,11 +19,10 @@ from typing import Dict, Optional, Tuple
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.dataloaders import video_input from official.projects.video_ssl.ops import video_ssl_preprocess_ops
from official.vision.beta.ops import preprocess_ops_3d from official.vision.dataloaders import video_input
from official.vision.projects.video_ssl.configs import video_ssl as exp_cfg from official.vision.ops import preprocess_ops_3d
from official.vision.projects.video_ssl.ops import video_ssl_preprocess_ops
IMAGE_KEY = 'image/encoded' IMAGE_KEY = 'image/encoded'
LABEL_KEY = 'clip/label/index' LABEL_KEY = 'clip/label/index'
......
...@@ -21,8 +21,8 @@ import numpy as np ...@@ -21,8 +21,8 @@ import numpy as np
from PIL import Image from PIL import Image
import tensorflow as tf import tensorflow as tf
from official.vision.projects.video_ssl.configs import video_ssl as exp_cfg from official.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.projects.video_ssl.dataloaders import video_ssl_input from official.projects.video_ssl.dataloaders import video_ssl_input
AUDIO_KEY = 'features/audio' AUDIO_KEY = 'features/audio'
......
...@@ -20,9 +20,9 @@ from typing import Mapping, Optional ...@@ -20,9 +20,9 @@ from typing import Mapping, Optional
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling import backbones from official.projects.video_ssl.configs import video_ssl as video_ssl_cfg
from official.vision.beta.modeling import factory_3d as model_factory from official.vision.modeling import backbones
from official.vision.projects.video_ssl.configs import video_ssl as video_ssl_cfg from official.vision.modeling import factory_3d as model_factory
layers = tf.keras.layers layers = tf.keras.layers
......
...@@ -12,10 +12,9 @@ ...@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import tensorflow as tf import tensorflow as tf
from official.vision.beta.ops import preprocess_ops_3d from official.projects.video_ssl.ops import video_ssl_preprocess_ops
from official.vision.projects.video_ssl.ops import video_ssl_preprocess_ops from official.vision.ops import preprocess_ops_3d
class VideoSslPreprocessOpsTest(tf.test.TestCase): class VideoSslPreprocessOpsTest(tf.test.TestCase):
......
...@@ -14,5 +14,5 @@ ...@@ -14,5 +14,5 @@
"""Tasks package definition.""" """Tasks package definition."""
from official.vision.projects.video_ssl.tasks import linear_eval from official.projects.video_ssl.tasks import linear_eval
from official.vision.projects.video_ssl.tasks import pretrain from official.projects.video_ssl.tasks import pretrain
...@@ -14,17 +14,15 @@ ...@@ -14,17 +14,15 @@
# Lint as: python3 # Lint as: python3
"""Video ssl linear evaluation task definition.""" """Video ssl linear evaluation task definition."""
from typing import Any, List, Optional, Tuple from typing import Any, Optional, List, Tuple
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
# pylint: disable=unused-import # pylint: disable=unused-import
from official.core import task_factory from official.core import task_factory
from official.vision.beta.tasks.google import video_classification from official.projects.video_ssl.configs.google import video_ssl as exp_cfg
from official.vision.projects.video_ssl.configs import video_ssl as exp_cfg from official.projects.video_ssl.modeling import video_ssl_model
from official.vision.projects.video_ssl.modeling import video_ssl_model from official.vision.tasks import video_classification
# pylint: disable=unused-import
@task_factory.register_task_cls(exp_cfg.VideoSSLEvalTask) @task_factory.register_task_cls(exp_cfg.VideoSSLEvalTask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment