Unverified Commit 5f5e264a authored by Partho's avatar Partho Committed by GitHub
Browse files

wrap forward passes with torch.no_grad() (#19413)

parent c6a928ca
...@@ -493,7 +493,8 @@ class FNetModelIntegrationTest(unittest.TestCase): ...@@ -493,7 +493,8 @@ class FNetModelIntegrationTest(unittest.TestCase):
model.to(torch_device) model.to(torch_device)
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
output = model(input_ids)[0] with torch.no_grad():
output = model(input_ids)[0]
vocab_size = 32000 vocab_size = 32000
...@@ -536,7 +537,8 @@ class FNetModelIntegrationTest(unittest.TestCase): ...@@ -536,7 +537,8 @@ class FNetModelIntegrationTest(unittest.TestCase):
model.to(torch_device) model.to(torch_device)
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
output = model(input_ids)[0] with torch.no_grad():
output = model(input_ids)[0]
expected_shape = torch.Size((1, 2)) expected_shape = torch.Size((1, 2))
self.assertEqual(output.shape, expected_shape) self.assertEqual(output.shape, expected_shape)
...@@ -551,7 +553,8 @@ class FNetModelIntegrationTest(unittest.TestCase): ...@@ -551,7 +553,8 @@ class FNetModelIntegrationTest(unittest.TestCase):
model.to(torch_device) model.to(torch_device)
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
output = model(input_ids)[0] with torch.no_grad():
output = model(input_ids)[0]
expected_shape = torch.Size((1, 6, model.config.hidden_size)) expected_shape = torch.Size((1, 6, model.config.hidden_size))
self.assertEqual(output.shape, expected_shape) self.assertEqual(output.shape, expected_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment