Commit 5b8d66b2 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Modify `augment.transform` to use `transforms.shape.rank` instead of `tf.rank`.

PiperOrigin-RevId: 313435294
parent a0d4ac79
......@@ -169,7 +169,7 @@ def transform(image: tf.Tensor, transforms) -> tf.Tensor:
"""Prepares input data for `image_ops.transform`."""
original_ndims = tf.rank(image)
transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)
if tf.rank(transforms) == 1:
if transforms.shape.rank == 1:
transforms = transforms[None]
image = to_4d(image)
image = image_ops.transform(
......@@ -989,7 +989,7 @@ class RandAugment(ImageAugment):
# pylint:disable=g-long-lambda
lambda selected_func=func, selected_args=args: selected_func(
image, *selected_args)))
# pylint:enable=g-long-lambda
# pylint:enable=g-long-lambda
image = tf.switch_case(branch_index=op_to_select,
branch_fns=branch_fns,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment