"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7addc9346c89563c0d36b30fa3534c58d3a1de05"
Unverified Commit d739a707 authored by Partho's avatar Partho Committed by GitHub
Browse files

wrap forward passes with torch.no_grad() (#19416)

parent 870a9542
...@@ -570,7 +570,8 @@ class TapasModelIntegrationTest(unittest.TestCase): ...@@ -570,7 +570,8 @@ class TapasModelIntegrationTest(unittest.TestCase):
table, queries = prepare_tapas_single_inputs_for_inference() table, queries = prepare_tapas_single_inputs_for_inference()
inputs = tokenizer(table=table, queries=queries, return_tensors="pt") inputs = tokenizer(table=table, queries=queries, return_tensors="pt")
inputs = {k: v.to(torch_device) for k, v in inputs.items()} inputs = {k: v.to(torch_device) for k, v in inputs.items()}
outputs = model(**inputs) with torch.no_grad():
outputs = model(**inputs)
# test the sequence output # test the sequence output
expected_slice = torch.tensor( expected_slice = torch.tensor(
[ [
...@@ -608,7 +609,8 @@ class TapasModelIntegrationTest(unittest.TestCase): ...@@ -608,7 +609,8 @@ class TapasModelIntegrationTest(unittest.TestCase):
table, queries = prepare_tapas_single_inputs_for_inference() table, queries = prepare_tapas_single_inputs_for_inference()
inputs = tokenizer(table=table, queries=queries, return_tensors="pt") inputs = tokenizer(table=table, queries=queries, return_tensors="pt")
inputs = {k: v.to(torch_device) for k, v in inputs.items()} inputs = {k: v.to(torch_device) for k, v in inputs.items()}
outputs = model(**inputs) with torch.no_grad():
outputs = model(**inputs)
# test the logits # test the logits
logits = outputs.logits logits = outputs.logits
expected_shape = torch.Size((1, 21)) expected_shape = torch.Size((1, 21))
...@@ -657,7 +659,8 @@ class TapasModelIntegrationTest(unittest.TestCase): ...@@ -657,7 +659,8 @@ class TapasModelIntegrationTest(unittest.TestCase):
table, queries = prepare_tapas_single_inputs_for_inference() table, queries = prepare_tapas_single_inputs_for_inference()
inputs = tokenizer(table=table, queries=queries, return_tensors="pt") inputs = tokenizer(table=table, queries=queries, return_tensors="pt")
inputs = {k: v.to(torch_device) for k, v in inputs.items()} inputs = {k: v.to(torch_device) for k, v in inputs.items()}
outputs = model(**inputs) with torch.no_grad():
outputs = model(**inputs)
# test the logits # test the logits
logits = outputs.logits logits = outputs.logits
expected_shape = torch.Size((1, 21)) expected_shape = torch.Size((1, 21))
...@@ -705,7 +708,8 @@ class TapasModelIntegrationTest(unittest.TestCase): ...@@ -705,7 +708,8 @@ class TapasModelIntegrationTest(unittest.TestCase):
inputs = tokenizer(table=table, queries=queries, padding="longest", return_tensors="pt") inputs = tokenizer(table=table, queries=queries, padding="longest", return_tensors="pt")
inputs_on_device = {k: v.to(torch_device) for k, v in inputs.items()} inputs_on_device = {k: v.to(torch_device) for k, v in inputs.items()}
outputs = model(**inputs_on_device) with torch.no_grad():
outputs = model(**inputs_on_device)
# test the logits # test the logits
logits = outputs.logits logits = outputs.logits
expected_shape = torch.Size((2, 28)) expected_shape = torch.Size((2, 28))
...@@ -774,15 +778,16 @@ class TapasModelIntegrationTest(unittest.TestCase): ...@@ -774,15 +778,16 @@ class TapasModelIntegrationTest(unittest.TestCase):
float_answer = torch.FloatTensor(float_answer).to(torch_device) float_answer = torch.FloatTensor(float_answer).to(torch_device)
# forward pass to get loss + logits: # forward pass to get loss + logits:
outputs = model( with torch.no_grad():
input_ids=input_ids, outputs = model(
attention_mask=attention_mask, input_ids=input_ids,
token_type_ids=token_type_ids, attention_mask=attention_mask,
labels=labels, token_type_ids=token_type_ids,
numeric_values=numeric_values, labels=labels,
numeric_values_scale=numeric_values_scale, numeric_values=numeric_values,
float_answer=float_answer, numeric_values_scale=numeric_values_scale,
) float_answer=float_answer,
)
# test the loss # test the loss
loss = outputs.loss loss = outputs.loss
...@@ -829,7 +834,8 @@ class TapasModelIntegrationTest(unittest.TestCase): ...@@ -829,7 +834,8 @@ class TapasModelIntegrationTest(unittest.TestCase):
table, queries = prepare_tapas_single_inputs_for_inference() table, queries = prepare_tapas_single_inputs_for_inference()
inputs = tokenizer(table=table, queries=queries, return_tensors="pt") inputs = tokenizer(table=table, queries=queries, return_tensors="pt")
inputs = {k: v.to(torch_device) for k, v in inputs.items()} inputs = {k: v.to(torch_device) for k, v in inputs.items()}
outputs = model(**inputs) with torch.no_grad():
outputs = model(**inputs)
# test the logits # test the logits
logits = outputs.logits logits = outputs.logits
expected_shape = torch.Size((1, 21)) expected_shape = torch.Size((1, 21))
...@@ -884,7 +890,8 @@ class TapasModelIntegrationTest(unittest.TestCase): ...@@ -884,7 +890,8 @@ class TapasModelIntegrationTest(unittest.TestCase):
table, queries = prepare_tapas_single_inputs_for_inference() table, queries = prepare_tapas_single_inputs_for_inference()
inputs = tokenizer(table=table, queries=queries, padding="longest", return_tensors="pt") inputs = tokenizer(table=table, queries=queries, padding="longest", return_tensors="pt")
inputs = {k: v.to(torch_device) for k, v in inputs.items()} inputs = {k: v.to(torch_device) for k, v in inputs.items()}
outputs = model(**inputs) with torch.no_grad():
outputs = model(**inputs)
# test the classification logits # test the classification logits
logits = outputs.logits logits = outputs.logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment