Unverified Commit 0225b135 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab-modeling

parents 7479dbb8 4c571a3c
...@@ -20,13 +20,13 @@ import tensorflow as tf ...@@ -20,13 +20,13 @@ import tensorflow as tf
# pylint: disable=unused-import # pylint: disable=unused-import
from official.core import input_reader from official.core import input_reader
from official.core import task_factory from official.core import task_factory
from official.vision.beta.modeling import factory_3d from official.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.tasks import video_classification from official.projects.video_ssl.dataloaders import video_ssl_input
from official.vision.projects.video_ssl.configs import video_ssl as exp_cfg from official.projects.video_ssl.losses import losses
from official.vision.projects.video_ssl.dataloaders import video_ssl_input from official.projects.video_ssl.modeling import video_ssl_model
from official.vision.projects.video_ssl.losses import losses from official.vision.modeling import factory_3d
from official.vision.projects.video_ssl.modeling import video_ssl_model from official.vision.tasks import video_classification
# pylint: disable=unused-import # pylint: enable=unused-import
@task_factory.register_task_cls(exp_cfg.VideoSSLPretrainTask) @task_factory.register_task_cls(exp_cfg.VideoSSLPretrainTask)
......
...@@ -22,12 +22,13 @@ import orbit ...@@ -22,12 +22,13 @@ import orbit
import tensorflow as tf import tensorflow as tf
# pylint: disable=unused-import # pylint: disable=unused-import
from official import vision
from official.core import exp_factory from official.core import exp_factory
from official.core import task_factory from official.core import task_factory
from official.modeling import optimization from official.modeling import optimization
from official.vision import beta from official.projects.video_ssl.tasks import pretrain
from official.vision.beta.dataloaders import tfexample_utils from official.vision.dataloaders import tfexample_utils
from official.vision.projects.video_ssl.tasks import pretrain # pylint: enable=unused-import
class VideoClassificationTaskTest(tf.test.TestCase): class VideoClassificationTaskTest(tf.test.TestCase):
......
...@@ -20,16 +20,16 @@ from absl import flags ...@@ -20,16 +20,16 @@ from absl import flags
import gin import gin
# pylint: disable=unused-import # pylint: disable=unused-import
from official.common import registry_imports
from official.common import distribute_utils from official.common import distribute_utils
from official.common import flags as tfm_flags from official.common import flags as tfm_flags
from official.core import task_factory from official.core import task_factory
from official.core import train_lib from official.core import train_lib
from official.core import train_utils from official.core import train_utils
from official.modeling import performance from official.modeling import performance
from official.vision.projects.video_ssl.modeling import video_ssl_model from official.projects.video_ssl.modeling import video_ssl_model
from official.vision.projects.video_ssl.tasks import linear_eval from official.projects.video_ssl.tasks import linear_eval
from official.vision.projects.video_ssl.tasks import pretrain from official.projects.video_ssl.tasks.google import pretrain
from official.vision import registry_imports
# pylint: disable=unused-import # pylint: disable=unused-import
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -22,10 +22,15 @@ from official.vision.beta.modeling.layers.nn_layers import StochasticDepth ...@@ -22,10 +22,15 @@ from official.vision.beta.modeling.layers.nn_layers import StochasticDepth
class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock): class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
"""TransformerEncoderBlock layer with stochastic depth.""" """TransformerEncoderBlock layer with stochastic depth."""
def __init__(self, *args, stochastic_depth_drop_rate=0.0, **kwargs): def __init__(self,
*args,
stochastic_depth_drop_rate=0.0,
return_attention=False,
**kwargs):
"""Initializes TransformerEncoderBlock.""" """Initializes TransformerEncoderBlock."""
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._return_attention = return_attention
def build(self, input_shape): def build(self, input_shape):
if self._stochastic_depth_drop_rate: if self._stochastic_depth_drop_rate:
...@@ -73,8 +78,9 @@ class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock): ...@@ -73,8 +78,9 @@ class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
if key_value is None: if key_value is None:
key_value = input_tensor key_value = input_tensor
attention_output = self._attention_layer( attention_output, attention_scores = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask) query=target_tensor, value=key_value, attention_mask=attention_mask,
return_attention_scores=True)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
if self._norm_first: if self._norm_first:
...@@ -95,12 +101,19 @@ class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock): ...@@ -95,12 +101,19 @@ class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
layer_output = self._output_dropout(layer_output) layer_output = self._output_dropout(layer_output)
if self._norm_first: if self._norm_first:
if self._return_attention:
return source_attention_output + self._stochastic_depth(
layer_output, training=training), attention_scores
else:
return source_attention_output + self._stochastic_depth( return source_attention_output + self._stochastic_depth(
layer_output, training=training) layer_output, training=training)
# During mixed precision training, layer norm output is always fp32 for now. # During mixed precision training, layer norm output is always fp32 for now.
# Casts fp32 for the subsequent add. # Casts fp32 for the subsequent add.
layer_output = tf.cast(layer_output, tf.float32) layer_output = tf.cast(layer_output, tf.float32)
return self._output_layer_norm( if self._return_attention:
layer_output + return self._output_layer_norm(layer_output + self._stochastic_depth(
self._stochastic_depth(attention_output, training=training)) attention_output, training=training)), attention_scores
else:
return self._output_layer_norm(layer_output + self._stochastic_depth(
attention_output, training=training))
...@@ -23,7 +23,7 @@ from official.modeling import hyperparams ...@@ -23,7 +23,7 @@ from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.projects.volumetric_models.configs import backbones from official.projects.volumetric_models.configs import backbones
from official.projects.volumetric_models.configs import decoders from official.projects.volumetric_models.configs import decoders
from official.vision.beta.configs import common from official.vision.configs import common
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
from typing import Any, Dict, Sequence, Tuple from typing import Any, Dict, Sequence, Tuple
import tensorflow as tf import tensorflow as tf
from official.vision.beta.dataloaders import decoder from official.vision.dataloaders import decoder
from official.vision.beta.dataloaders import parser from official.vision.dataloaders import parser
class Decoder(decoder.Decoder): class Decoder(decoder.Decoder):
......
...@@ -20,7 +20,7 @@ from absl.testing import parameterized ...@@ -20,7 +20,7 @@ from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official.projects.volumetric_models.dataloaders import segmentation_input_3d from official.projects.volumetric_models.dataloaders import segmentation_input_3d
from official.vision.beta.dataloaders import tfexample_utils from official.vision.dataloaders import tfexample_utils
class InputReaderTest(parameterized.TestCase, tf.test.TestCase): class InputReaderTest(parameterized.TestCase, tf.test.TestCase):
......
...@@ -25,7 +25,7 @@ from typing import Any, Mapping, Sequence ...@@ -25,7 +25,7 @@ from typing import Any, Mapping, Sequence
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams from official.modeling import hyperparams
from official.projects.volumetric_models.modeling import nn_blocks_3d from official.projects.volumetric_models.modeling import nn_blocks_3d
from official.vision.beta.modeling.backbones import factory from official.vision.modeling.backbones import factory
layers = tf.keras.layers layers = tf.keras.layers
......
...@@ -21,8 +21,8 @@ import tensorflow as tf ...@@ -21,8 +21,8 @@ import tensorflow as tf
from official.modeling import hyperparams from official.modeling import hyperparams
from official.projects.volumetric_models.modeling.decoders import factory as decoder_factory from official.projects.volumetric_models.modeling.decoders import factory as decoder_factory
from official.projects.volumetric_models.modeling.heads import segmentation_heads_3d from official.projects.volumetric_models.modeling.heads import segmentation_heads_3d
from official.vision.beta.modeling import segmentation_model from official.vision.modeling import segmentation_model
from official.vision.beta.modeling.backbones import factory as backbone_factory from official.vision.modeling.backbones import factory as backbone_factory
def build_segmentation_model_3d( def build_segmentation_model_3d(
......
...@@ -20,7 +20,7 @@ from typing import Sequence, Union ...@@ -20,7 +20,7 @@ from typing import Sequence, Union
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.layers import nn_layers from official.vision.modeling.layers import nn_layers
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
......
...@@ -21,7 +21,7 @@ import tensorflow as tf ...@@ -21,7 +21,7 @@ import tensorflow as tf
from official.projects.volumetric_models.modeling import backbones from official.projects.volumetric_models.modeling import backbones
from official.projects.volumetric_models.modeling import decoders from official.projects.volumetric_models.modeling import decoders
from official.projects.volumetric_models.modeling.heads import segmentation_heads_3d from official.projects.volumetric_models.modeling.heads import segmentation_heads_3d
from official.vision.beta.modeling import segmentation_model from official.vision.modeling import segmentation_model
class SegmentationNetworkUNet3DTest(parameterized.TestCase, tf.test.TestCase): class SegmentationNetworkUNet3DTest(parameterized.TestCase, tf.test.TestCase):
......
...@@ -43,7 +43,7 @@ from official.common import registry_imports # pylint: disable=unused-import ...@@ -43,7 +43,7 @@ from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.projects.volumetric_models.serving import semantic_segmentation_3d from official.projects.volumetric_models.serving import semantic_segmentation_3d
from official.vision.beta.serving import export_saved_model_lib from official.vision.serving import export_saved_model_lib
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -22,7 +22,7 @@ import tensorflow as tf ...@@ -22,7 +22,7 @@ import tensorflow as tf
from official.projects.volumetric_models.modeling import backbones from official.projects.volumetric_models.modeling import backbones
from official.projects.volumetric_models.modeling import decoders from official.projects.volumetric_models.modeling import decoders
from official.projects.volumetric_models.modeling import factory from official.projects.volumetric_models.modeling import factory
from official.vision.beta.serving import export_base from official.vision.serving import export_base
class SegmentationModule(export_base.ExportModule): class SegmentationModule(export_base.ExportModule):
......
...@@ -30,7 +30,7 @@ from official.projects.volumetric_models.evaluation import segmentation_metrics ...@@ -30,7 +30,7 @@ from official.projects.volumetric_models.evaluation import segmentation_metrics
from official.projects.volumetric_models.modeling import backbones from official.projects.volumetric_models.modeling import backbones
from official.projects.volumetric_models.modeling import decoders from official.projects.volumetric_models.modeling import decoders
from official.projects.volumetric_models.tasks import semantic_segmentation_3d as img_seg_task from official.projects.volumetric_models.tasks import semantic_segmentation_3d as img_seg_task
from official.vision.beta.dataloaders import tfexample_utils from official.vision.dataloaders import tfexample_utils
class SemanticSegmentationTaskTest(tf.test.TestCase, parameterized.TestCase): class SemanticSegmentationTaskTest(tf.test.TestCase, parameterized.TestCase):
......
...@@ -19,7 +19,7 @@ import gin # pylint: disable=unused-import ...@@ -19,7 +19,7 @@ import gin # pylint: disable=unused-import
from official.common import flags as tfm_flags from official.common import flags as tfm_flags
from official.projects.volumetric_models import registry_imports # pylint: disable=unused-import from official.projects.volumetric_models import registry_imports # pylint: disable=unused-import
from official.vision.beta import train from official.vision import train
def main(_): def main(_):
......
...@@ -21,7 +21,7 @@ from absl import logging ...@@ -21,7 +21,7 @@ from absl import logging
from absl.testing import flagsaver from absl.testing import flagsaver
import tensorflow as tf import tensorflow as tf
from official.projects.volumetric_models import train as train_lib from official.projects.volumetric_models import train as train_lib
from official.vision.beta.dataloaders import tfexample_utils from official.vision.dataloaders import tfexample_utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -21,7 +21,7 @@ from official.core import config_definitions as cfg ...@@ -21,7 +21,7 @@ from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.vision.beta.configs import common from official.vision.configs import common
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -27,9 +27,9 @@ from typing import Dict ...@@ -27,9 +27,9 @@ from typing import Dict
import tensorflow as tf import tensorflow as tf
from official.projects.yt8m.dataloaders import utils from official.projects.yt8m.dataloaders import utils
from official.vision.beta.configs import video_classification as exp_cfg from official.vision.configs import video_classification as exp_cfg
from official.vision.beta.dataloaders import decoder from official.vision.dataloaders import decoder
from official.vision.beta.dataloaders import parser from official.vision.dataloaders import parser
def resize_axis(tensor, axis, new_size, fill_value=0): def resize_axis(tensor, axis, new_size, fill_value=0):
......
...@@ -21,7 +21,7 @@ from official.common import flags as tfm_flags ...@@ -21,7 +21,7 @@ from official.common import flags as tfm_flags
from official.projects.yt8m.configs import yt8m from official.projects.yt8m.configs import yt8m
from official.projects.yt8m.tasks import yt8m_task from official.projects.yt8m.tasks import yt8m_task
# pylint: enable=unused-import # pylint: enable=unused-import
from official.vision.beta import train from official.vision import train
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -22,7 +22,7 @@ from absl.testing import flagsaver ...@@ -22,7 +22,7 @@ from absl.testing import flagsaver
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.projects.yt8m import train as train_lib from official.projects.yt8m import train as train_lib
from official.vision.beta.dataloaders import tfexample_utils from official.vision.dataloaders import tfexample_utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment