Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
692c5be7
"docs/source/zh/main_classes/quantization.md" did not exist on "74a3cebfa51b539bfcfa79b33686cc090b7074e8"
Unverified
Commit
692c5be7
authored
Oct 11, 2022
by
Partho
Committed by
GitHub
Oct 10, 2022
Browse files
wrap forward passes with torch.no_grad() (#19439)
parent
a7bc4221
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
32 deletions
+36
-32
tests/models/visual_bert/test_modeling_visual_bert.py
tests/models/visual_bert/test_modeling_visual_bert.py
+36
-32
No files found.
tests/models/visual_bert/test_modeling_visual_bert.py
View file @
692c5be7
...
...
@@ -568,14 +568,15 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
         attention_mask = torch.tensor([1] * 6).reshape(1, -1)
         visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)

-        output = model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            visual_embeds=visual_embeds,
-            visual_attention_mask=visual_attention_mask,
-            visual_token_type_ids=visual_token_type_ids,
-        )
+        with torch.no_grad():
+            output = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                visual_embeds=visual_embeds,
+                visual_attention_mask=visual_attention_mask,
+                visual_token_type_ids=visual_token_type_ids,
+            )

         vocab_size = 30522
...
...
@@ -606,14 +607,15 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
         attention_mask = torch.tensor([1] * 6).reshape(1, -1)
         visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)

-        output = model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            visual_embeds=visual_embeds,
-            visual_attention_mask=visual_attention_mask,
-            visual_token_type_ids=visual_token_type_ids,
-        )
+        with torch.no_grad():
+            output = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                visual_embeds=visual_embeds,
+                visual_attention_mask=visual_attention_mask,
+                visual_token_type_ids=visual_token_type_ids,
+            )

         # vocab_size = 30522
...
...
@@ -637,14 +639,15 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
         attention_mask = torch.tensor([1] * 6).reshape(1, -1)
         visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)

-        output = model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            visual_embeds=visual_embeds,
-            visual_attention_mask=visual_attention_mask,
-            visual_token_type_ids=visual_token_type_ids,
-        )
+        with torch.no_grad():
+            output = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                visual_embeds=visual_embeds,
+                visual_attention_mask=visual_attention_mask,
+                visual_token_type_ids=visual_token_type_ids,
+            )

         # vocab_size = 30522
...
...
@@ -667,14 +670,15 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
         visual_token_type_ids = torch.ones(size=(1, 4, 10), dtype=torch.long)
         visual_attention_mask = torch.ones_like(visual_token_type_ids)

-        output = model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            visual_embeds=visual_embeds,
-            visual_attention_mask=visual_attention_mask,
-            visual_token_type_ids=visual_token_type_ids,
-        )
+        with torch.no_grad():
+            output = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                visual_embeds=visual_embeds,
+                visual_attention_mask=visual_attention_mask,
+                visual_token_type_ids=visual_token_type_ids,
+            )

         # vocab_size = 30522
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment