"vscode:/vscode.git/clone" did not exist on "76bf1e8a4400b90ee3f3dbaa7cf984f1dfc780c9"
Commit ea36aff1 authored by Smit Hinsu's avatar Smit Hinsu Committed by A. Unique TensorFlower
Browse files

Replace dynamic Squeeze with reshape to 1d

TPUs don't support dynamic Squeeze as it won't be able to identify the dimensions to remove if the input shape is dynamic. For example, [<=16, 1] could return a vector or scalar result depending on the size of the dynamic dim.

PiperOrigin-RevId: 415083667
parent f3105295
...@@ -76,7 +76,7 @@ def get_loss_fn(num_classes): ...@@ -76,7 +76,7 @@ def get_loss_fn(num_classes):
def classification_loss_fn(labels, logits): def classification_loss_fn(labels, logits):
"""Classification loss.""" """Classification loss."""
labels = tf.squeeze(labels) labels = tf.reshape(labels, [-1])
log_probs = tf.nn.log_softmax(logits, axis=-1) log_probs = tf.nn.log_softmax(logits, axis=-1)
one_hot_labels = tf.one_hot( one_hot_labels = tf.one_hot(
tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32) tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment