Unverified Commit 6bf88537 authored by Kian Sierra McGettigan's avatar Kian Sierra McGettigan Committed by GitHub
Browse files

Prophetnet batch dimension inversion fix (#21870)

* decoder forward pass is working

* no model has forward pass returning attentions

* decoder ngram changed to not mix batch size

* current basic forward pass returns identical result

* passed test_model attentions

* passed test_encoder_decoder_model_generate

* passed test_headmasking

* removed old block

* removed comments bug/fixme

* removed bug comments

* applied styling

* applied fix-copies

* applied ngram forward comments

* corrected dimension notation

* applied styling and comment fixes

* changed asserts for raise ValueError

* changed question gen test

* updated hidden_states integration test

* applied styling
parent 99ba36e7
...@@ -1206,7 +1206,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase): ...@@ -1206,7 +1206,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 12, 30522)) expected_shape = torch.Size((1, 12, 30522))
self.assertEqual(output_predited_logits.shape, expected_shape) self.assertEqual(output_predited_logits.shape, expected_shape)
expected_slice = torch.tensor( expected_slice = torch.tensor(
[[[-7.6213, -7.9008, -7.9979], [-7.6834, -7.8467, -8.2187], [-7.5326, -7.4762, -8.1914]]] [[[-7.7729, -8.0343, -8.26001], [-7.74213, -7.8629, -8.6000], [-7.7328, -7.8269, -8.5264]]]
).to(torch_device) ).to(torch_device)
# self.assertTrue(torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4)) # self.assertTrue(torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4))
assert torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4) assert torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4)
...@@ -1306,7 +1306,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase): ...@@ -1306,7 +1306,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
EXPECTED_QUESTIONS = [ EXPECTED_QUESTIONS = [
"along with paul allen, who founded microsoft?", "along with paul allen, who founded microsoft?",
"what year was microsoft founded?", "what year was microsoft founded?",
"on what date was microsoft founded?", "when was microsoft founded?",
] ]
self.assertListEqual( self.assertListEqual(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment