"...core/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "f00f025605d435c4c95f9163afa7019fee23b7c7"
Unverified Commit cc1d0685 authored by Katie Le's avatar Katie Le Committed by GitHub
Browse files

Wrap RemBert integration test forward passes with torch.no_grad() (#21503)



added with torch.no_grad() to the integration tests and applied make style
Co-authored-by: default avatarBibi <Bibi@katies-mac.local>
parent 5b67ab99
......@@ -464,7 +464,8 @@ class RemBertModelIntegrationTest(unittest.TestCase):
model = RemBertModel.from_pretrained("google/rembert")
input_ids = torch.tensor([[312, 56498, 313, 2125, 313]])
segment_ids = torch.tensor([[0, 0, 0, 1, 1]])
output = model(input_ids, token_type_ids=segment_ids, output_hidden_states=True)
with torch.no_grad():
output = model(input_ids, token_type_ids=segment_ids, output_hidden_states=True)
hidden_size = 1152
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment