Unverified Commit cc1d0685 authored by Katie Le's avatar Katie Le Committed by GitHub
Browse files

Wrap RemBert integration test forward passes with torch.no_grad() (#21503)



added with torch.no_grad() to the integration tests and applied make style
Co-authored-by: default avatarBibi <Bibi@katies-mac.local>
parent 5b67ab99
...@@ -464,6 +464,7 @@ class RemBertModelIntegrationTest(unittest.TestCase): ...@@ -464,6 +464,7 @@ class RemBertModelIntegrationTest(unittest.TestCase):
model = RemBertModel.from_pretrained("google/rembert") model = RemBertModel.from_pretrained("google/rembert")
input_ids = torch.tensor([[312, 56498, 313, 2125, 313]]) input_ids = torch.tensor([[312, 56498, 313, 2125, 313]])
segment_ids = torch.tensor([[0, 0, 0, 1, 1]]) segment_ids = torch.tensor([[0, 0, 0, 1, 1]])
with torch.no_grad():
output = model(input_ids, token_type_ids=segment_ids, output_hidden_states=True) output = model(input_ids, token_type_ids=segment_ids, output_hidden_states=True)
hidden_size = 1152 hidden_size = 1152
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment