Commit a9322830 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 433341765
parent 078b78b3
......@@ -21,7 +21,7 @@ from typing import Optional
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.qat.vision.configs import common
from official.vision.beta.configs import image_classification
from official.vision.configs import image_classification
@dataclasses.dataclass
......
......@@ -22,7 +22,7 @@ from official.core import exp_factory
from official.projects.qat.vision.configs import common
from official.projects.qat.vision.configs import image_classification as qat_exp_cfg
from official.vision import beta
from official.vision.beta.configs import image_classification as exp_cfg
from official.vision.configs import image_classification as exp_cfg
class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
......
......@@ -20,8 +20,8 @@ from typing import Optional
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.qat.vision.configs import common
from official.vision.beta.configs import retinanet
from official.vision.beta.configs.google import backbones
from official.vision.configs import retinanet
from official.vision.configs.google import backbones
@dataclasses.dataclass
......
......@@ -22,7 +22,7 @@ from official.core import exp_factory
from official.projects.qat.vision.configs import common
from official.projects.qat.vision.configs import retinanet as qat_exp_cfg
from official.vision import beta
from official.vision.beta.configs import retinanet as exp_cfg
from official.vision.configs import retinanet as exp_cfg
class RetinaNetConfigTest(tf.test.TestCase, parameterized.TestCase):
......
......@@ -20,7 +20,7 @@ from typing import Optional
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.qat.vision.configs import common
from official.vision.beta.configs import semantic_segmentation
from official.vision.configs import semantic_segmentation
@dataclasses.dataclass
......
......@@ -22,7 +22,7 @@ from official.core import exp_factory
from official.projects.qat.vision.configs import common
from official.projects.qat.vision.configs import semantic_segmentation as qat_exp_cfg
from official.vision import beta
from official.vision.beta.configs import semantic_segmentation as exp_cfg
from official.vision.configs import semantic_segmentation as exp_cfg
class SemanticSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):
......
......@@ -22,12 +22,12 @@ from official.projects.qat.vision.configs import common
from official.projects.qat.vision.modeling import segmentation_model as qat_segmentation_model
from official.projects.qat.vision.n_bit import schemes as n_bit_schemes
from official.projects.qat.vision.quantization import schemes
from official.vision.beta import configs
from official.vision.beta.modeling import classification_model
from official.vision.beta.modeling import retinanet_model
from official.vision.beta.modeling.decoders import aspp
from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.modeling.layers import nn_layers
from official.vision import configs
from official.vision.modeling import classification_model
from official.vision.modeling import retinanet_model
from official.vision.modeling.decoders import aspp
from official.vision.modeling.heads import segmentation_heads
from official.vision.modeling.layers import nn_layers
def build_qat_classification_model(
......
......@@ -21,12 +21,12 @@ import tensorflow as tf
from official.projects.qat.vision.configs import common
from official.projects.qat.vision.modeling import factory as qat_factory
from official.vision.beta.configs import backbones
from official.vision.beta.configs import decoders
from official.vision.beta.configs import image_classification as classification_cfg
from official.vision.beta.configs import retinanet as retinanet_cfg
from official.vision.beta.configs import semantic_segmentation as semantic_segmentation_cfg
from official.vision.beta.modeling import factory
from official.vision.configs import backbones
from official.vision.configs import decoders
from official.vision.configs import image_classification as classification_cfg
from official.vision.configs import retinanet as retinanet_cfg
from official.vision.configs import semantic_segmentation as semantic_segmentation_cfg
from official.vision.modeling import factory
class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
......
......@@ -24,7 +24,7 @@ import tensorflow_model_optimization as tfmot
from official.modeling import tf_utils
from official.projects.qat.vision.modeling.layers import nn_layers as qat_nn_layers
from official.projects.qat.vision.quantization import configs
from official.vision.beta.modeling.layers import nn_layers
from official.vision.modeling.layers import nn_layers
class NoOpActivation:
......
......@@ -22,8 +22,8 @@ import tensorflow_model_optimization as tfmot
from official.modeling import tf_utils
from official.projects.qat.vision.quantization import configs
from official.projects.qat.vision.quantization import helper
from official.vision.beta.modeling.decoders import aspp
from official.vision.beta.modeling.layers import nn_layers
from official.vision.modeling.decoders import aspp
from official.vision.modeling.layers import nn_layers
# Type annotations.
......
......@@ -24,7 +24,7 @@ import tensorflow_model_optimization as tfmot
from official.modeling import tf_utils
from official.projects.qat.vision.n_bit import configs
from official.projects.qat.vision.n_bit import nn_layers as qat_nn_layers
from official.vision.beta.modeling.layers import nn_layers
from official.vision.modeling.layers import nn_layers
class NoOpActivation:
......
......@@ -21,7 +21,7 @@ import tensorflow_model_optimization as tfmot
from official.modeling import tf_utils
from official.projects.qat.vision.n_bit import configs
from official.vision.beta.modeling.layers import nn_layers
from official.vision.modeling.layers import nn_layers
# Type annotations.
States = Dict[str, tf.Tensor]
......
......@@ -19,7 +19,7 @@ import tensorflow as tf
from official.core import task_factory
from official.projects.qat.vision.configs import image_classification as exp_cfg
from official.projects.qat.vision.modeling import factory
from official.vision.beta.tasks import image_classification
from official.vision.tasks import image_classification
@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
......
......@@ -18,7 +18,7 @@ import tensorflow as tf
from official.core import task_factory
from official.projects.qat.vision.configs import retinanet as exp_cfg
from official.projects.qat.vision.modeling import factory
from official.vision.beta.tasks import retinanet
from official.vision.tasks import retinanet
@task_factory.register_task_cls(exp_cfg.RetinaNetTask)
......
......@@ -23,7 +23,7 @@ from official.core import exp_factory
from official.modeling import optimization
from official.projects.qat.vision.tasks import retinanet
from official.vision import beta
from official.vision.beta.configs import retinanet as exp_cfg
from official.vision.configs import retinanet as exp_cfg
class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase):
......
......@@ -18,7 +18,7 @@ import tensorflow as tf
from official.core import task_factory
from official.projects.qat.vision.configs import semantic_segmentation as exp_cfg
from official.projects.qat.vision.modeling import factory
from official.vision.beta.tasks import semantic_segmentation
from official.vision.tasks import semantic_segmentation
@task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
......
......@@ -18,7 +18,7 @@ from absl import app
from official.common import flags as tfm_flags
from official.projects.qat.vision import registry_imports # pylint: disable=unused-import
from official.vision.beta import train
from official.vision import train
if __name__ == '__main__':
......
......@@ -15,8 +15,8 @@
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.common import registry_imports
from official.projects.volumetric_models.configs import semantic_segmentation_3d as semantic_segmentation_3d_cfg
from official.projects.volumetric_models.modeling import backbones
from official.projects.volumetric_models.modeling import decoders
from official.projects.volumetric_models.tasks import semantic_segmentation_3d
from official.vision import registry_imports
......@@ -714,7 +714,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
'use_depthwise': self._use_depthwise,
'use_residual': self._use_residual,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
'norm_epsilon': self._norm_epsilon,
'output_intermediate_endpoints': self._output_intermediate_endpoints
}
base_config = super(InvertedBottleneckBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment