chenpangpang/transformers, commit bdd690a7 (unverified)
Authored May 02, 2022 by yujun; committed by GitHub on May 02, 2022
Parent: 9586e222

add torch.no_grad when in eval mode (#17020)

* add torch.no_grad when in eval mode
* make style quality
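Background for the change (not part of the diff): model.eval() only switches modules such as dropout and batch norm to their inference behavior; it does not stop autograd from recording the forward pass, so evaluation still pays for graph bookkeeping. Wrapping the forward call in torch.no_grad() is what skips that. A minimal, self-contained sketch, with nn.Linear and a random input standing in for the real model and batch:

import torch
from torch import nn

model = nn.Linear(4, 2)   # stand-in for the Transformers model used in the scripts
model.eval()              # inference behavior for dropout/batchnorm; autograd still active

x = torch.randn(1, 4)
print(model(x).requires_grad)      # True: a graph is still recorded in eval mode

with torch.no_grad():              # the context this commit adds around model(**batch)
    print(model(x).requires_grad)  # False: no graph is built, so evaluation uses less memory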
Showing 4 changed files with 10 additions and 4 deletions (+10, -4)
examples/pytorch/image-classification/run_image_classification_no_trainer.py (+2, -1)
examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py (+2, -1)
examples/pytorch/text-classification/run_glue_no_trainer.py (+3, -1)
templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py (+3, -1)
examples/pytorch/image-classification/run_image_classification_no_trainer.py

@@ -469,7 +469,8 @@ def main():
         model.eval()
         samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
-            outputs = model(**batch)
+            with torch.no_grad():
+                outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1)
             predictions, references = accelerator.gather((predictions, batch["labels"]))
             # If we are in a multiprocess environment, the last batch has duplicates
examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py

@@ -579,7 +579,8 @@ def main():
         model.eval()
         samples_seen = 0
         for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
-            outputs = model(**batch)
+            with torch.no_grad():
+                outputs = model(**batch)
             upsampled_logits = torch.nn.functional.interpolate(
                 outputs.logits, size=batch["labels"].shape[-2:], mode="bilinear", align_corners=False
examples/pytorch/text-classification/run_glue_no_trainer.py

@@ -22,6 +22,7 @@ import random
 from pathlib import Path

 import datasets
+import torch
 from datasets import load_dataset, load_metric
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

@@ -514,7 +515,8 @@ def main():
         model.eval()
         samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
-            outputs = model(**batch)
+            with torch.no_grad():
+                outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
             predictions, references = accelerator.gather((predictions, batch["labels"]))
             # If we are in a multiprocess environment, the last batch has duplicates
templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py

@@ -28,6 +28,7 @@ from dataclasses import dataclass, field
 from typing import Optional, List

 import datasets
+import torch
 from datasets import load_dataset

 import transformers

@@ -871,7 +872,8 @@ def main():
         model.eval()
         for step, batch in enumerate(eval_dataloader):
-            outputs = model(**batch)
+            with torch.no_grad():
+                outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1)
             metric.add_batch(
                 predictions=accelerator.gather(predictions),
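All four files apply the same before/after pattern shown above. For reference, a hedged sketch of what the updated evaluation loop looks like end to end, using a dummy model and DataLoader as stand-ins for the accelerate-prepared objects the example scripts actually use:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins for illustration only; the real scripts build these with transformers + accelerate.
model = nn.Linear(8, 3)
eval_dataloader = DataLoader(TensorDataset(torch.randn(32, 8), torch.randint(0, 3, (32,))), batch_size=8)

model.eval()
for step, (inputs, labels) in enumerate(eval_dataloader):
    with torch.no_grad():                # added by this commit around the forward pass
        logits = model(inputs)
    predictions = logits.argmax(dim=-1)  # metric computation is unchanged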