Unverified Commit f71fb5c3 authored by Tavin Turner's avatar Tavin Turner Committed by GitHub
Browse files

Add 'with torch.no_grad()' to BertGeneration integration test forward passes (#14963)

parent d2183a46
...@@ -307,6 +307,7 @@ class BertGenerationEncoderIntegrationTest(unittest.TestCase): ...@@ -307,6 +307,7 @@ class BertGenerationEncoderIntegrationTest(unittest.TestCase):
def test_inference_no_head_absolute_embedding(self):
    """Integration test: run a frozen BertGeneration encoder checkpoint on a
    fixed token sequence and check the output hidden-state shape.

    The forward pass is wrapped in ``torch.no_grad()`` so no autograd graph
    is built — this is inference only, and skipping gradient bookkeeping
    cuts memory use and speeds the test up.
    """
    model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
    # Fixed input: [CLS] hello , my dog is cute [SEP] under the BERT vocab.
    input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
    with torch.no_grad():
        output = model(input_ids)[0]
    # Encoder hidden size for this L-24 checkpoint is 1024; seq len is 8.
    expected_shape = torch.Size([1, 8, 1024])
    self.assertEqual(output.shape, expected_shape)
...@@ -322,6 +323,7 @@ class BertGenerationDecoderIntegrationTest(unittest.TestCase): ...@@ -322,6 +323,7 @@ class BertGenerationDecoderIntegrationTest(unittest.TestCase):
def test_inference_no_head_absolute_embedding(self):
    """Integration test: run a frozen BertGeneration decoder checkpoint on a
    fixed token sequence and check the output logits shape.

    The forward pass is wrapped in ``torch.no_grad()`` so no autograd graph
    is built — this is inference only, and skipping gradient bookkeeping
    cuts memory use and speeds the test up.
    """
    model = BertGenerationDecoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
    # Fixed input: [CLS] hello , my dog is cute [SEP] under the BERT vocab.
    input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
    with torch.no_grad():
        output = model(input_ids)[0]
    # Decoder head projects to the 50358-token vocabulary; seq len is 8.
    expected_shape = torch.Size([1, 8, 50358])
    self.assertEqual(output.shape, expected_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment