"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "94a90215c334fea297dc48ba2e927dcc96d5e0b0"
Commit 73b5be67 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Modify `augment.transform` to use `tf.rank` rather than `len(shape)`.

PiperOrigin-RevId: 313259032
parent 986b0825
...@@ -169,7 +169,7 @@ def transform(image: tf.Tensor, transforms) -> tf.Tensor: ...@@ -169,7 +169,7 @@ def transform(image: tf.Tensor, transforms) -> tf.Tensor:
"""Prepares input data for `image_ops.transform`.""" """Prepares input data for `image_ops.transform`."""
original_ndims = tf.rank(image) original_ndims = tf.rank(image)
transforms = tf.convert_to_tensor(transforms, dtype=tf.float32) transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)
if len(tf.shape(transforms)) == 1: if tf.rank(transforms) == 1:
transforms = transforms[None] transforms = transforms[None]
image = to_4d(image) image = to_4d(image)
image = image_ops.transform( image = image_ops.transform(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment