Unverified Commit 708b19eb authored by Matt, committed by GitHub

Stop confusing the TF compiler with ModelOutput objects (#28712)

* Stop confusing the TF compiler with ModelOutput objects

* Stop confusing the TF compiler with ModelOutput objects
parent a638de19
@@ -1171,7 +1171,7 @@ class TFBlipForConditionalGeneration(TFBlipPreTrainedModel):
             attention_mask=attention_mask,
             encoder_hidden_states=image_embeds,
             labels=labels,
-            return_dict=return_dict,
+            return_dict=False,
             training=training,
         )
@@ -1179,12 +1179,19 @@ class TFBlipForConditionalGeneration(TFBlipPreTrainedModel):
             outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
             return tuple(output for output in outputs if output is not None)
 
-        if outputs.loss is not None and outputs.loss.shape.rank == 0:
-            outputs.loss = tf.reshape(outputs.loss, (1,))
+        if labels is not None:
+            loss = outputs[0]
+            logits = outputs[1]
+        else:
+            loss = None
+            logits = outputs[0]
+
+        if loss is not None and loss.shape.rank == 0:
+            loss = tf.reshape(loss, (1,))
 
         return TFBlipForConditionalGenerationModelOutput(
-            loss=outputs.loss,
-            logits=outputs.logits,
+            loss=loss,
+            logits=logits,
             image_embeds=image_embeds,
             last_hidden_state=vision_outputs.last_hidden_state,
             hidden_states=vision_outputs.hidden_states,
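The change above calls the text decoder with return_dict=False and unpacks the resulting plain tuple positionally, instead of reading and mutating .loss/.logits on a ModelOutput object inside the traced call. A minimal standalone sketch of that pattern, with a hypothetical toy_decoder standing in for the real BLIP text decoder (none of these names come from the commit):

import tensorflow as tf

def toy_decoder(hidden_states, labels=None):
    # Stand-in for the inner decoder called with return_dict=False: it returns
    # a plain tuple, (loss, logits) when labels are given, (logits,) otherwise.
    logits = tf.matmul(hidden_states, tf.ones((4, 10)))
    if labels is not None:
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
        )
        return (loss, logits)
    return (logits,)

@tf.function
def forward(hidden_states, labels=None):
    outputs = toy_decoder(hidden_states, labels=labels)
    # Unpack positionally, as the commit does, instead of attribute access.
    if labels is not None:
        loss, logits = outputs[0], outputs[1]
    else:
        loss, logits = None, outputs[0]
    # Reshape through a local tensor; the old code mutated outputs.loss in place.
    if loss is not None and loss.shape.rank == 0:
        loss = tf.reshape(loss, (1,))
    return loss, logits

loss, logits = forward(tf.random.normal((2, 4)), labels=tf.constant([1, 3]))

Keeping graph-side values as local tensors and building the ModelOutput only at the boundary means the compiler traces nothing but plain tensors and tuples.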
@@ -1060,7 +1060,8 @@ class TFBlipTextLMHeadModel(TFBlipTextPreTrainedModel):
             labels = labels[:, 1:]
             labels = tf.reshape(labels, (-1,))
             # Keras won't give us label smoothing for sparse CE, so we de-sparsify things here
-            one_hot_labels = tf.one_hot(labels, depth=self.config.vocab_size, dtype=tf.float32)
+            # Use relu to clamp masked labels at 0 to avoid NaN (we will be zeroing those out later anyway)
+            one_hot_labels = tf.one_hot(tf.nn.relu(labels), depth=self.config.vocab_size, dtype=tf.float32)
             loss_fct = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1, reduction="none")
             masked_positions = tf.cast(tf.not_equal(labels, -100), dtype=tf.float32)
             lm_loss = loss_fct(one_hot_labels, shifted_prediction_scores)
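The tf.nn.relu clamp addresses the -100 sentinel used for ignored positions: clamping it to a valid index before tf.one_hot keeps the label-smoothed loss finite, and those positions are zeroed out afterwards via masked_positions. A self-contained sketch of that computation; the vocab size, the example tensors, and the final masked average are illustrative assumptions rather than code from this file:

import tensorflow as tf

vocab_size = 8  # assumed toy value, not self.config.vocab_size
labels = tf.constant([2, 5, -100, -100])  # -100 marks ignored (masked) tokens
logits = tf.random.normal((4, vocab_size))  # stand-in for shifted_prediction_scores

# Clamp -100 to the valid index 0 before one-hot encoding, as the commit does;
# the loss at those positions is zeroed out below, so index 0 is harmless.
one_hot_labels = tf.one_hot(tf.nn.relu(labels), depth=vocab_size, dtype=tf.float32)

loss_fct = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, label_smoothing=0.1, reduction="none"
)
per_token_loss = loss_fct(one_hot_labels, logits)  # shape (4,), one value per token

# Zero out masked positions and (as an assumed final step) average over real tokens.
masked_positions = tf.cast(tf.not_equal(labels, -100), dtype=tf.float32)
lm_loss = tf.reduce_sum(per_token_loss * masked_positions) / tf.reduce_sum(masked_positions)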