Unverified Commit 23c146c3 authored by Katie Le's avatar Katie Le Committed by GitHub
Browse files

Added with torch.no_grad() to XLM-Roberta integration test (#21547)



* added with torch.no_grad() to the integration tests and applied make style

* added with torch.no_grad() to xlm roberta forward pass

---------
Co-authored-by: default avatarBibi <Bibi@katies-mac.local>
parent 04b2f13c
......@@ -43,8 +43,8 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase):
# xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.base')
# xlmr.eval()
# expected_output_values_last_dim = xlmr.extract_features(input_ids[0])[:, :, -1]
output = model(input_ids)["last_hidden_state"].detach()
with torch.no_grad():
output = model(input_ids)["last_hidden_state"].detach()
self.assertEqual(output.shape, expected_output_shape)
# compare the actual values for a slice of last dim
self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
......@@ -62,8 +62,8 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase):
# xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.large')
# xlmr.eval()
# expected_output_values_last_dim = xlmr.extract_features(input_ids[0])[:, :, -1]
output = model(input_ids)["last_hidden_state"].detach()
with torch.no_grad():
output = model(input_ids)["last_hidden_state"].detach()
self.assertEqual(output.shape, expected_output_shape)
# compare the actual values for a slice of last dim
self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment