Commit 749baad3 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 304649192
parent d3320242
......@@ -25,7 +25,7 @@ from __future__ import print_function
import math
import tensorflow as tf
from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union
from typing import Any, Dict, List, Optional, Text, Tuple
from tensorflow.python.keras.layers.preprocessing import image_preprocessing as image_ops
......@@ -66,7 +66,7 @@ def to_4d(image: tf.Tensor) -> tf.Tensor:
return tf.reshape(image, new_shape)
def from_4d(image: tf.Tensor, ndims: int) -> tf.Tensor:
def from_4d(image: tf.Tensor, ndims: tf.Tensor) -> tf.Tensor:
"""Converts a 4D image back to `ndims` rank."""
shape = tf.shape(image)
begin = tf.cast(tf.less_equal(ndims, 3), dtype=tf.int32)
......@@ -75,7 +75,7 @@ def from_4d(image: tf.Tensor, ndims: int) -> tf.Tensor:
return tf.reshape(image, new_shape)
def _convert_translation_to_transform(translations) -> tf.Tensor:
def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
"""Converts translations to a projective transform.
The translation matrix looks like this:
......@@ -121,9 +121,9 @@ def _convert_translation_to_transform(translations) -> tf.Tensor:
def _convert_angles_to_transform(
angles: Union[Iterable[float], float],
image_width: int,
image_height: int) -> tf.Tensor:
angles: tf.Tensor,
image_width: tf.Tensor,
image_height: tf.Tensor) -> tf.Tensor:
"""Converts an angle or angles to a projective transform.
Args:
......@@ -209,7 +209,7 @@ def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
"""
# Convert from degrees to radians.
degrees_to_radians = math.pi / 180.0
radians = degrees * degrees_to_radians
radians = tf.cast(degrees * degrees_to_radians, tf.float32)
original_ndims = tf.rank(image)
image = to_4d(image)
......
......@@ -52,14 +52,21 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(augment.transform(image, transforms=[1]*8),
[[4, 4], [4, 4]])
def disable_test_translate(self, dtype):
def test_translate(self, dtype):
image = tf.constant(
[[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]],
[[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 0, 1, 0],
[0, 1, 0, 1]],
dtype=dtype)
translations = [-1, -1]
translated = augment.translate(image=image,
translations=translations)
expected = [[1, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0], [0, 0, 0, 0]]
expected = [
[1, 0, 1, 1],
[0, 1, 0, 0],
[1, 0, 1, 1],
[1, 0, 1, 1]]
self.assertAllEqual(translated, expected)
def test_translate_shapes(self, dtype):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment