Unverified commit 0225b135 authored by Srihari Humbarwadi, committed by GitHub

Merge branch 'tensorflow:master' into panoptic-deeplab-modeling

parents 7479dbb8 4c571a3c
......@@ -20,13 +20,13 @@ import tensorflow as tf
# pylint: disable=unused-import
from official.core import input_reader
from official.core import task_factory
from official.vision.beta.modeling import factory_3d
from official.vision.beta.tasks import video_classification
from official.vision.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.projects.video_ssl.dataloaders import video_ssl_input
from official.vision.projects.video_ssl.losses import losses
from official.vision.projects.video_ssl.modeling import video_ssl_model
# pylint: disable=unused-import
from official.projects.video_ssl.configs import video_ssl as exp_cfg
from official.projects.video_ssl.dataloaders import video_ssl_input
from official.projects.video_ssl.losses import losses
from official.projects.video_ssl.modeling import video_ssl_model
from official.vision.modeling import factory_3d
from official.vision.tasks import video_classification
# pylint: enable=unused-import
@task_factory.register_task_cls(exp_cfg.VideoSSLPretrainTask)
......
......@@ -22,12 +22,13 @@ import orbit
import tensorflow as tf
# pylint: disable=unused-import
from official import vision
from official.core import exp_factory
from official.core import task_factory
from official.modeling import optimization
from official.vision import beta
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.projects.video_ssl.tasks import pretrain
from official.projects.video_ssl.tasks import pretrain
from official.vision.dataloaders import tfexample_utils
# pylint: enable=unused-import
class VideoClassificationTaskTest(tf.test.TestCase):
......
......@@ -20,16 +20,16 @@ from absl import flags
import gin
# pylint: disable=unused-import
from official.common import registry_imports
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.vision.projects.video_ssl.modeling import video_ssl_model
from official.vision.projects.video_ssl.tasks import linear_eval
from official.vision.projects.video_ssl.tasks import pretrain
from official.projects.video_ssl.modeling import video_ssl_model
from official.projects.video_ssl.tasks import linear_eval
from official.projects.video_ssl.tasks.google import pretrain
from official.vision import registry_imports
# pylint: disable=unused-import
FLAGS = flags.FLAGS
......
......@@ -22,10 +22,15 @@ from official.vision.beta.modeling.layers.nn_layers import StochasticDepth
class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
  """TransformerEncoderBlock layer with stochastic depth."""
  def __init__(self, *args, stochastic_depth_drop_rate=0.0, **kwargs):
  def __init__(self,
               *args,
               stochastic_depth_drop_rate=0.0,
               return_attention=False,
               **kwargs):
    """Initializes TransformerEncoderBlock."""
    super().__init__(*args, **kwargs)
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    self._return_attention = return_attention
  def build(self, input_shape):
    if self._stochastic_depth_drop_rate:
......@@ -73,8 +78,9 @@ class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
    if key_value is None:
      key_value = input_tensor
    attention_output = self._attention_layer(
        query=target_tensor, value=key_value, attention_mask=attention_mask)
    attention_output, attention_scores = self._attention_layer(
        query=target_tensor, value=key_value, attention_mask=attention_mask,
        return_attention_scores=True)
    attention_output = self._attention_dropout(attention_output)
    if self._norm_first:
......@@ -95,12 +101,19 @@ class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
    layer_output = self._output_dropout(layer_output)
    if self._norm_first:
      return source_attention_output + self._stochastic_depth(
          layer_output, training=training)
      if self._return_attention:
        return source_attention_output + self._stochastic_depth(
            layer_output, training=training), attention_scores
      else:
        return source_attention_output + self._stochastic_depth(
            layer_output, training=training)
    # During mixed precision training, layer norm output is always fp32 for now.
    # Casts fp32 for the subsequent add.
    layer_output = tf.cast(layer_output, tf.float32)
    return self._output_layer_norm(
        layer_output +
        self._stochastic_depth(attention_output, training=training))
    if self._return_attention:
      return self._output_layer_norm(layer_output + self._stochastic_depth(
          attention_output, training=training)), attention_scores
    else:
      return self._output_layer_norm(layer_output + self._stochastic_depth(
          attention_output, training=training))
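Two behaviors are grafted onto the base `TransformerEncoderBlock` in the hunks above. First, when `stochastic_depth_drop_rate > 0`, the residual branches are wrapped in the `StochasticDepth` layer imported in the hunk context. As a rough reference only, a common stand-alone formulation of stochastic depth (an illustrative sketch, not guaranteed to match the Model Garden layer line for line) is:

```python
import tensorflow as tf


def stochastic_depth(residual: tf.Tensor, drop_rate: float,
                     training: bool) -> tf.Tensor:
  """Randomly drops a whole residual branch per example during training."""
  if not training or drop_rate == 0.0:
    return residual
  keep_prob = 1.0 - drop_rate
  batch_size = tf.shape(residual)[0]
  # One keep/drop decision per example, broadcast across all other dims.
  mask_shape = [batch_size] + [1] * (residual.shape.rank - 1)
  mask = tf.floor(keep_prob + tf.random.uniform(mask_shape,
                                                dtype=residual.dtype))
  # Inverted scaling keeps the branch's expected value unchanged.
  return residual / keep_prob * mask
```

Second, `return_attention=True` makes `call()` also return the per-head attention map, which it obtains by passing `return_attention_scores=True` to the block's underlying multi-head attention layer (Keras `MultiHeadAttention` or a subclass of it). A minimal, self-contained illustration of that Keras API, with arbitrary layer sizes and shapes not taken from this diff:

```python
import tensorflow as tf

# Arbitrary example sizes; not taken from this diff.
mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=32)

x = tf.random.normal([2, 16, 128])  # [batch, seq_len, hidden]
output, scores = mha(query=x, value=x, return_attention_scores=True)

print(output.shape)  # (2, 16, 128)  -> same shape as the query input
print(scores.shape)  # (2, 4, 16, 16) -> [batch, heads, query_len, key_len]
```

With `return_attention=True`, the patched block surfaces exactly this `scores` tensor alongside its usual output.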
......@@ -23,7 +23,7 @@ from official.modeling import hyperparams
from official.modeling import optimization
from official.projects.volumetric_models.configs import backbones
from official.projects.volumetric_models.configs import decoders
from official.vision.beta.configs import common
from official.vision.configs import common
@dataclasses.dataclass
......
......@@ -16,8 +16,8 @@
from typing import Any, Dict, Sequence, Tuple
import tensorflow as tf
from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser
from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
class Decoder(decoder.Decoder):
......
......@@ -20,7 +20,7 @@ from absl.testing import parameterized
import tensorflow as tf
from official.projects.volumetric_models.dataloaders import segmentation_input_3d
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.dataloaders import tfexample_utils
class InputReaderTest(parameterized.TestCase, tf.test.TestCase):
......
......@@ -25,7 +25,7 @@ from typing import Any, Mapping, Sequence
import tensorflow as tf
from official.modeling import hyperparams
from official.projects.volumetric_models.modeling import nn_blocks_3d
from official.vision.beta.modeling.backbones import factory
from official.vision.modeling.backbones import factory
layers = tf.keras.layers
......
......@@ -21,8 +21,8 @@ import tensorflow as tf
from official.modeling import hyperparams
from official.projects.volumetric_models.modeling.decoders import factory as decoder_factory
from official.projects.volumetric_models.modeling.heads import segmentation_heads_3d
from official.vision.beta.modeling import segmentation_model
from official.vision.beta.modeling.backbones import factory as backbone_factory
from official.vision.modeling import segmentation_model
from official.vision.modeling.backbones import factory as backbone_factory
def build_segmentation_model_3d(
......
......@@ -20,7 +20,7 @@ from typing import Sequence, Union
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.layers import nn_layers
from official.vision.modeling.layers import nn_layers
@tf.keras.utils.register_keras_serializable(package='Vision')
......
......@@ -21,7 +21,7 @@ import tensorflow as tf
from official.projects.volumetric_models.modeling import backbones
from official.projects.volumetric_models.modeling import decoders
from official.projects.volumetric_models.modeling.heads import segmentation_heads_3d
from official.vision.beta.modeling import segmentation_model
from official.vision.modeling import segmentation_model
class SegmentationNetworkUNet3DTest(parameterized.TestCase, tf.test.TestCase):
......
......@@ -43,7 +43,7 @@ from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.projects.volumetric_models.serving import semantic_segmentation_3d
from official.vision.beta.serving import export_saved_model_lib
from official.vision.serving import export_saved_model_lib
FLAGS = flags.FLAGS
......
......@@ -22,7 +22,7 @@ import tensorflow as tf
from official.projects.volumetric_models.modeling import backbones
from official.projects.volumetric_models.modeling import decoders
from official.projects.volumetric_models.modeling import factory
from official.vision.beta.serving import export_base
from official.vision.serving import export_base
class SegmentationModule(export_base.ExportModule):
......
......@@ -30,7 +30,7 @@ from official.projects.volumetric_models.evaluation import segmentation_metrics
from official.projects.volumetric_models.modeling import backbones
from official.projects.volumetric_models.modeling import decoders
from official.projects.volumetric_models.tasks import semantic_segmentation_3d as img_seg_task
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.dataloaders import tfexample_utils
class SemanticSegmentationTaskTest(tf.test.TestCase, parameterized.TestCase):
......
......@@ -19,7 +19,7 @@ import gin # pylint: disable=unused-import
from official.common import flags as tfm_flags
from official.projects.volumetric_models import registry_imports # pylint: disable=unused-import
from official.vision.beta import train
from official.vision import train
def main(_):
......
......@@ -21,7 +21,7 @@ from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
from official.projects.volumetric_models import train as train_lib
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.dataloaders import tfexample_utils
FLAGS = flags.FLAGS
......
......@@ -21,7 +21,7 @@ from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.beta.configs import common
from official.vision.configs import common
FLAGS = flags.FLAGS
......
......@@ -27,9 +27,9 @@ from typing import Dict
import tensorflow as tf
from official.projects.yt8m.dataloaders import utils
from official.vision.beta.configs import video_classification as exp_cfg
from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser
from official.vision.configs import video_classification as exp_cfg
from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
def resize_axis(tensor, axis, new_size, fill_value=0):
......
......@@ -21,7 +21,7 @@ from official.common import flags as tfm_flags
from official.projects.yt8m.configs import yt8m
from official.projects.yt8m.tasks import yt8m_task
# pylint: enable=unused-import
from official.vision.beta import train
from official.vision import train
if __name__ == '__main__':
......
......@@ -22,7 +22,7 @@ from absl.testing import flagsaver
import numpy as np
import tensorflow as tf
from official.projects.yt8m import train as train_lib
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.dataloaders import tfexample_utils
FLAGS = flags.FLAGS
......
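Apart from the transformer-block change above, the hunks in this merge are a mechanical path migration: `official.vision.beta.*` imports move to `official.vision.*`, and project code under `official.vision.projects.*` moves to `official.projects.*`. A before/after sketch of the pattern, using paths that appear in the hunks above (importing them assumes a TF Model Garden checkout or the `tf-models-official` package, which may not ship every project):

```python
# Before this merge (paths removed in the hunks above):
#   from official.vision.beta.dataloaders import tfexample_utils
#   from official.vision.beta.modeling import factory_3d
#   from official.vision.projects.video_ssl.losses import losses

# After this merge (paths added in the hunks above):
from official.vision.dataloaders import tfexample_utils
from official.vision.modeling import factory_3d
from official.projects.video_ssl.losses import losses
```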