Unverified Commit cc1d0685 authored by Katie Le's avatar Katie Le Committed by GitHub
Browse files

Wrap RemBert integration test forward passes with torch.no_grad() (#21503)



added with torch.no_grad() to the integration tests and applied make style
Co-authored-by: default avatarBibi <Bibi@katies-mac.local>
parent 5b67ab99
......@@ -464,6 +464,7 @@ class RemBertModelIntegrationTest(unittest.TestCase):
model = RemBertModel.from_pretrained("google/rembert")
input_ids = torch.tensor([[312, 56498, 313, 2125, 313]])
segment_ids = torch.tensor([[0, 0, 0, 1, 1]])
with torch.no_grad():
output = model(input_ids, token_type_ids=segment_ids, output_hidden_states=True)
hidden_size = 1152
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment