Unverified Commit 692c5be7 authored by Partho's avatar Partho Committed by GitHub
Browse files

wrap forward passes with torch.no_grad() (#19439)

parent a7bc4221
...@@ -568,14 +568,15 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
        attention_mask = torch.tensor([1] * 6).reshape(1, -1)
        visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)

        with torch.no_grad():
            output = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                visual_embeds=visual_embeds,
                visual_attention_mask=visual_attention_mask,
                visual_token_type_ids=visual_token_type_ids,
            )

        vocab_size = 30522
...@@ -606,14 +607,15 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
        attention_mask = torch.tensor([1] * 6).reshape(1, -1)
        visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)

        with torch.no_grad():
            output = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                visual_embeds=visual_embeds,
                visual_attention_mask=visual_attention_mask,
                visual_token_type_ids=visual_token_type_ids,
            )

        # vocab_size = 30522
...@@ -637,14 +639,15 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
        attention_mask = torch.tensor([1] * 6).reshape(1, -1)
        visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)

        with torch.no_grad():
            output = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                visual_embeds=visual_embeds,
                visual_attention_mask=visual_attention_mask,
                visual_token_type_ids=visual_token_type_ids,
            )

        # vocab_size = 30522
...@@ -667,14 +670,15 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
        visual_token_type_ids = torch.ones(size=(1, 4, 10), dtype=torch.long)
        visual_attention_mask = torch.ones_like(visual_token_type_ids)

        with torch.no_grad():
            output = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                visual_embeds=visual_embeds,
                visual_attention_mask=visual_attention_mask,
                visual_token_type_ids=visual_token_type_ids,
            )

        # vocab_size = 30522
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.