Commit 5b8d66b2 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Modify `augment.transform` to use `transforms.shape.rank` instead of `tf.rank`.

PiperOrigin-RevId: 313435294
parent a0d4ac79
...@@ -169,7 +169,7 @@ def transform(image: tf.Tensor, transforms) -> tf.Tensor: ...@@ -169,7 +169,7 @@ def transform(image: tf.Tensor, transforms) -> tf.Tensor:
"""Prepares input data for `image_ops.transform`.""" """Prepares input data for `image_ops.transform`."""
original_ndims = tf.rank(image) original_ndims = tf.rank(image)
transforms = tf.convert_to_tensor(transforms, dtype=tf.float32) transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)
if tf.rank(transforms) == 1: if transforms.shape.rank == 1:
transforms = transforms[None] transforms = transforms[None]
image = to_4d(image) image = to_4d(image)
image = image_ops.transform( image = image_ops.transform(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment