Unverified Commit a9782881 authored by Partho's avatar Partho Committed by GitHub
Browse files

wrap forward passes with torch.no_grad() (#19273)

parent d6e92044
...@@ -627,7 +627,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase): ...@@ -627,7 +627,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
model.to(torch_device) model.to(torch_device)
input_ids = torch.tensor([[20920, 232, 328, 1437] * 1024], dtype=torch.long, device=torch_device) input_ids = torch.tensor([[20920, 232, 328, 1437] * 1024], dtype=torch.long, device=torch_device)
outputs = model(input_ids) with torch.no_grad():
outputs = model(input_ids)
prediction_logits = outputs.prediction_logits prediction_logits = outputs.prediction_logits
seq_relationship_logits = outputs.seq_relationship_logits seq_relationship_logits = outputs.seq_relationship_logits
...@@ -655,7 +656,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase): ...@@ -655,7 +656,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
model.to(torch_device) model.to(torch_device)
input_ids = torch.tensor([[20920, 232, 328, 1437] * 512], dtype=torch.long, device=torch_device) input_ids = torch.tensor([[20920, 232, 328, 1437] * 512], dtype=torch.long, device=torch_device)
outputs = model(input_ids) with torch.no_grad():
outputs = model(input_ids)
prediction_logits = outputs.prediction_logits prediction_logits = outputs.prediction_logits
seq_relationship_logits = outputs.seq_relationship_logits seq_relationship_logits = outputs.seq_relationship_logits
...@@ -920,7 +922,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase): ...@@ -920,7 +922,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
model.eval() model.eval()
input_ids = torch.tensor([200 * [10] + 40 * [2] + [1]], device=torch_device, dtype=torch.long) input_ids = torch.tensor([200 * [10] + 40 * [2] + [1]], device=torch_device, dtype=torch.long)
output = model(input_ids).to_tuple()[0] with torch.no_grad():
output = model(input_ids).to_tuple()[0]
# fmt: off # fmt: off
target = torch.tensor( target = torch.tensor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment