Commit fddab2eb authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 303780351
parent f3024690
......@@ -24,7 +24,7 @@ from __future__ import division
from __future__ import print_function
import math
import tensorflow.compat.v2 as tf
import tensorflow as tf
from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union
from tensorflow.python.keras.layers.preprocessing import image_preprocessing as image_ops
......@@ -75,8 +75,7 @@ def from_4d(image: tf.Tensor, ndims: int) -> tf.Tensor:
return tf.reshape(image, new_shape)
def _convert_translation_to_transform(
translations: Iterable[int]) -> tf.Tensor:
def _convert_translation_to_transform(translations) -> tf.Tensor:
"""Converts translations to a projective transform.
The translation matrix looks like this:
......@@ -166,8 +165,7 @@ def _convert_angles_to_transform(
)
def transform(image: tf.Tensor,
transforms: Iterable[float]) -> tf.Tensor:
def transform(image: tf.Tensor, transforms) -> tf.Tensor:
"""Prepares input data for `image_ops.transform`."""
original_ndims = tf.rank(image)
transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)
......@@ -181,8 +179,7 @@ def transform(image: tf.Tensor,
return from_4d(image, original_ndims)
def translate(image: tf.Tensor,
translations: Iterable[int]) -> tf.Tensor:
def translate(image: tf.Tensor, translations) -> tf.Tensor:
"""Translates image(s) by provided vectors.
Args:
......@@ -577,7 +574,7 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
return image
def _randomly_negate_tensor(tensor: tf.Tensor) -> tf.Tensor:
def _randomly_negate_tensor(tensor):
"""With 50% prob turn the tensor negative."""
should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor)
......
......@@ -21,7 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized
import tensorflow.compat.v2 as tf
import tensorflow as tf
from official.vision.image_classification import augment
......@@ -133,5 +133,4 @@ class AutoaugmentTest(tf.test.TestCase):
self.assertEqual((224, 224, 3), image.shape)
if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main()
......@@ -27,7 +27,7 @@ from typing import Any, Tuple, Text, Optional, Mapping
from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v2 as tf
import tensorflow as tf
from official.modeling import performance
from official.modeling.hyperparams import params_dict
......@@ -423,5 +423,4 @@ if __name__ == '__main__':
flags.mark_flag_as_required('model_type')
flags.mark_flag_as_required('dataset')
assert tf.version.VERSION.startswith('2.')
app.run(main)
......@@ -30,7 +30,7 @@ from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, T
from absl import flags
from absl.testing import parameterized
import tensorflow.compat.v2 as tf
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
......@@ -313,5 +313,4 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
tf.io.gfile.rmtree(model_dir)
if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main()
......@@ -23,7 +23,7 @@ import os
from typing import Any, List, Optional, Tuple, Mapping, Union
from absl import logging
from dataclasses import dataclass
import tensorflow.compat.v2 as tf
import tensorflow as tf
import tensorflow_datasets as tfds
from official.modeling.hyperparams import base_config
......
......@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
import tensorflow as tf
from typing import Text, Optional
from tensorflow.python.tpu import tpu_function
......
......@@ -30,7 +30,7 @@ from typing import Any, Dict, Optional, Text, Tuple
from absl import logging
from dataclasses import dataclass
import tensorflow.compat.v2 as tf
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
......
......@@ -20,7 +20,7 @@ from __future__ import print_function
from typing import Any, List, Mapping
import tensorflow.compat.v2 as tf
import tensorflow as tf
BASE_LEARNING_RATE = 0.1
......
......@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
import tensorflow as tf
from official.vision.image_classification import learning_rate
......@@ -86,5 +86,4 @@ class LearningRateTests(tf.test.TestCase):
if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main()
......@@ -88,5 +88,4 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
if __name__ == "__main__":
tf.compat.v1.enable_v2_behavior()
tf.test.main()
......@@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function
from absl import logging
import tensorflow.compat.v2 as tf
import tensorflow as tf
import tensorflow_addons as tfa
from typing import Any, Dict, Text
......
......@@ -19,7 +19,7 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow.compat.v2 as tf
import tensorflow as tf
from absl.testing import parameterized
from official.vision.image_classification import optimizer_factory
......@@ -111,5 +111,4 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main()
......@@ -19,7 +19,7 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow.compat.v2 as tf
import tensorflow as tf
from typing import List, Optional, Text, Tuple
from official.vision.image_classification import augment
......
......@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
import tensorflow as tf
from official.modeling import performance
from official.staging.training import grad_utils
......
......@@ -24,7 +24,7 @@ import os
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
import tensorflow as tf
from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment