Unverified Commit d739a707 authored by Partho's avatar Partho Committed by GitHub
Browse files

wrap forward passes with torch.no_grad() (#19416)

parent 870a9542
......@@ -570,6 +570,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
table, queries = prepare_tapas_single_inputs_for_inference()
inputs = tokenizer(table=table, queries=queries, return_tensors="pt")
inputs = {k: v.to(torch_device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
# test the sequence output
expected_slice = torch.tensor(
......@@ -608,6 +609,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
table, queries = prepare_tapas_single_inputs_for_inference()
inputs = tokenizer(table=table, queries=queries, return_tensors="pt")
inputs = {k: v.to(torch_device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
# test the logits
logits = outputs.logits
......@@ -657,6 +659,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
table, queries = prepare_tapas_single_inputs_for_inference()
inputs = tokenizer(table=table, queries=queries, return_tensors="pt")
inputs = {k: v.to(torch_device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
# test the logits
logits = outputs.logits
......@@ -705,6 +708,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
inputs = tokenizer(table=table, queries=queries, padding="longest", return_tensors="pt")
inputs_on_device = {k: v.to(torch_device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs_on_device)
# test the logits
logits = outputs.logits
......@@ -774,6 +778,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
float_answer = torch.FloatTensor(float_answer).to(torch_device)
# forward pass to get loss + logits:
with torch.no_grad():
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
......@@ -829,6 +834,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
table, queries = prepare_tapas_single_inputs_for_inference()
inputs = tokenizer(table=table, queries=queries, return_tensors="pt")
inputs = {k: v.to(torch_device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
# test the logits
logits = outputs.logits
......@@ -884,6 +890,7 @@ class TapasModelIntegrationTest(unittest.TestCase):
table, queries = prepare_tapas_single_inputs_for_inference()
inputs = tokenizer(table=table, queries=queries, padding="longest", return_tensors="pt")
inputs = {k: v.to(torch_device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
# test the classification logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment