Commit 5b8d66b2 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Modify `augment.transform` to use `transforms.shape.rank` instead of `tf.rank`.

PiperOrigin-RevId: 313435294
parent a0d4ac79
...@@ -169,7 +169,7 @@ def transform(image: tf.Tensor, transforms) -> tf.Tensor: ...@@ -169,7 +169,7 @@ def transform(image: tf.Tensor, transforms) -> tf.Tensor:
"""Prepares input data for `image_ops.transform`.""" """Prepares input data for `image_ops.transform`."""
original_ndims = tf.rank(image) original_ndims = tf.rank(image)
transforms = tf.convert_to_tensor(transforms, dtype=tf.float32) transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)
if tf.rank(transforms) == 1: if transforms.shape.rank == 1:
transforms = transforms[None] transforms = transforms[None]
image = to_4d(image) image = to_4d(image)
image = image_ops.transform( image = image_ops.transform(
...@@ -989,7 +989,7 @@ class RandAugment(ImageAugment): ...@@ -989,7 +989,7 @@ class RandAugment(ImageAugment):
# pylint:disable=g-long-lambda # pylint:disable=g-long-lambda
lambda selected_func=func, selected_args=args: selected_func( lambda selected_func=func, selected_args=args: selected_func(
image, *selected_args))) image, *selected_args)))
# pylint:enable=g-long-lambda # pylint:enable=g-long-lambda
image = tf.switch_case(branch_index=op_to_select, image = tf.switch_case(branch_index=op_to_select,
branch_fns=branch_fns, branch_fns=branch_fns,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment