Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a9782881
"llama/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "d25efe395415353152feafb9b766a8a7a1496517"
Unverified
Commit
a9782881
authored
Oct 04, 2022
by
Partho
Committed by
GitHub
Oct 04, 2022
Browse files
wrap forward passes with torch.no_grad() (#19273)
parent
d6e92044
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
3 deletions
+6
-3
tests/models/big_bird/test_modeling_big_bird.py
tests/models/big_bird/test_modeling_big_bird.py
+6
-3
No files found.
tests/models/big_bird/test_modeling_big_bird.py
View file @
a9782881
...
@@ -627,7 +627,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
...
@@ -627,7 +627,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
input_ids
=
torch
.
tensor
([[
20920
,
232
,
328
,
1437
]
*
1024
],
dtype
=
torch
.
long
,
device
=
torch_device
)
input_ids
=
torch
.
tensor
([[
20920
,
232
,
328
,
1437
]
*
1024
],
dtype
=
torch
.
long
,
device
=
torch_device
)
outputs
=
model
(
input_ids
)
with
torch
.
no_grad
():
outputs
=
model
(
input_ids
)
prediction_logits
=
outputs
.
prediction_logits
prediction_logits
=
outputs
.
prediction_logits
seq_relationship_logits
=
outputs
.
seq_relationship_logits
seq_relationship_logits
=
outputs
.
seq_relationship_logits
...
@@ -655,7 +656,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
...
@@ -655,7 +656,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
input_ids
=
torch
.
tensor
([[
20920
,
232
,
328
,
1437
]
*
512
],
dtype
=
torch
.
long
,
device
=
torch_device
)
input_ids
=
torch
.
tensor
([[
20920
,
232
,
328
,
1437
]
*
512
],
dtype
=
torch
.
long
,
device
=
torch_device
)
outputs
=
model
(
input_ids
)
with
torch
.
no_grad
():
outputs
=
model
(
input_ids
)
prediction_logits
=
outputs
.
prediction_logits
prediction_logits
=
outputs
.
prediction_logits
seq_relationship_logits
=
outputs
.
seq_relationship_logits
seq_relationship_logits
=
outputs
.
seq_relationship_logits
...
@@ -920,7 +922,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
...
@@ -920,7 +922,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
model
.
eval
()
model
.
eval
()
input_ids
=
torch
.
tensor
([
200
*
[
10
]
+
40
*
[
2
]
+
[
1
]],
device
=
torch_device
,
dtype
=
torch
.
long
)
input_ids
=
torch
.
tensor
([
200
*
[
10
]
+
40
*
[
2
]
+
[
1
]],
device
=
torch_device
,
dtype
=
torch
.
long
)
output
=
model
(
input_ids
).
to_tuple
()[
0
]
with
torch
.
no_grad
():
output
=
model
(
input_ids
).
to_tuple
()[
0
]
# fmt: off
# fmt: off
target
=
torch
.
tensor
(
target
=
torch
.
tensor
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment