"...resnet50_tensorflow.git" did not exist on "adc01cd76ae0d9d3b2e8dde3ec6bf4086f7da046"
Unverified Commit dc05dd53 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix TF Causal LM models' returned logits (#15256)



* Fix TF Causal LM models' returned logits

* Fix expected shape in the tests
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent af5c3329
...@@ -1542,9 +1542,9 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1542,9 +1542,9 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
if inputs["labels"] is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:] labels = inputs["labels"][:, 1:]
loss = self.hf_compute_loss(labels=labels, logits=logits) loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
if not inputs["return_dict"]: if not inputs["return_dict"]:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
......
...@@ -735,9 +735,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -735,9 +735,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
loss = None loss = None
if inputs["labels"] is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:] labels = inputs["labels"][:, 1:]
loss = self.hf_compute_loss(labels, logits) loss = self.hf_compute_loss(labels, shifted_logits)
if not inputs["return_dict"]: if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
......
...@@ -949,9 +949,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -949,9 +949,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
loss = None loss = None
if inputs["labels"] is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:] labels = inputs["labels"][:, 1:]
loss = self.hf_compute_loss(labels, logits) loss = self.hf_compute_loss(labels, shifted_logits)
if not inputs["return_dict"]: if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
......
...@@ -656,9 +656,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin ...@@ -656,9 +656,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
loss = None loss = None
if inputs["labels"] is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:] labels = inputs["labels"][:, 1:]
loss = self.hf_compute_loss(labels, logits) loss = self.hf_compute_loss(labels, shifted_logits)
if not inputs["return_dict"]: if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
......
...@@ -1275,9 +1275,9 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos ...@@ -1275,9 +1275,9 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
if inputs["labels"] is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:] labels = inputs["labels"][:, 1:]
loss = self.hf_compute_loss(labels=labels, logits=logits) loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
if not inputs["return_dict"]: if not inputs["return_dict"]:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
......
...@@ -1310,9 +1310,9 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos ...@@ -1310,9 +1310,9 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
if inputs["labels"] is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:] labels = inputs["labels"][:, 1:]
loss = self.hf_compute_loss(labels=labels, logits=logits) loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
if not inputs["return_dict"]: if not inputs["return_dict"]:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
......
...@@ -1035,9 +1035,9 @@ class TFRoFormerForCausalLM(TFRoFormerPreTrainedModel, TFCausalLanguageModelingL ...@@ -1035,9 +1035,9 @@ class TFRoFormerForCausalLM(TFRoFormerPreTrainedModel, TFCausalLanguageModelingL
if inputs["labels"] is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:] labels = inputs["labels"][:, 1:]
loss = self.hf_compute_loss(labels=labels, logits=logits) loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
if not inputs["return_dict"]: if not inputs["return_dict"]:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
......
...@@ -1262,9 +1262,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca ...@@ -1262,9 +1262,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
if inputs["labels"] is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:] labels = inputs["labels"][:, 1:]
loss = self.hf_compute_loss(labels=labels, logits=logits) loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
if not inputs["return_dict"]: if not inputs["return_dict"]:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
......
...@@ -240,7 +240,7 @@ class TFEncoderDecoderMixin: ...@@ -240,7 +240,7 @@ class TFEncoderDecoderMixin:
assert "loss" in outputs_encoder_decoder assert "loss" in outputs_encoder_decoder
batch_size, seq_len = decoder_input_ids.shape batch_size, seq_len = decoder_input_ids.shape
expected_shape = (batch_size, seq_len - 1, decoder_config.vocab_size) expected_shape = (batch_size, seq_len, decoder_config.vocab_size)
self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape) self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape)
self.assertEqual( self.assertEqual(
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
......
...@@ -231,7 +231,7 @@ class TFVisionEncoderDecoderMixin: ...@@ -231,7 +231,7 @@ class TFVisionEncoderDecoderMixin:
self.assertIn("loss", outputs_encoder_decoder) self.assertIn("loss", outputs_encoder_decoder)
batch_size, seq_len = decoder_input_ids.shape batch_size, seq_len = decoder_input_ids.shape
expected_shape = (batch_size, seq_len - 1, decoder_config.vocab_size) expected_shape = (batch_size, seq_len, decoder_config.vocab_size)
self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape) self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape)
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[0], pixel_values.shape[0]) self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[0], pixel_values.shape[0])
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1], config.hidden_size) self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1], config.hidden_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment