Unverified commit 2a91a9ef, authored by Matt and committed via GitHub
Browse files

Fix PT-TF equivalence test for GPT1 (#22586)

* Re-enable skipped test and fix the hidden state shape issue

* Actually fix the bug instead of just doing something wrong
parent 06842849
@@ -748,6 +748,12 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
         )
         hidden_states = transformer_outputs[0]
         hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
+        if return_dict and output_hidden_states:
+            # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
+            # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged)
+            all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
+        else:
+            all_hidden_states = None
         lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
         mc_logits = tf.squeeze(mc_logits, axis=-1)
@@ -758,7 +764,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
         return TFOpenAIGPTDoubleHeadsModelOutput(
             logits=lm_logits,
             mc_logits=mc_logits,
-            hidden_states=transformer_outputs.hidden_states,
+            hidden_states=all_hidden_states,
             attentions=transformer_outputs.attentions,
         )
......
@@ -274,10 +274,6 @@ class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
         model = OpenAIGPTModel.from_pretrained(model_name)
         self.assertIsNotNone(model)

-    @unittest.skip("Fix me Matt")
-    def test_pt_tf_model_equivalence(self):
-        pass
-

 @require_torch
 class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment