Unverified Commit bdd690a7 authored by yujun's avatar yujun Committed by GitHub
Browse files

add torch.no_grad when in eval mode (#17020)

* add torch.no_grad when in eval mode

* make style quality
parent 9586e222
@@ -469,6 +469,7 @@ def main():
model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather((predictions, batch["labels"]))
......
@@ -579,6 +579,7 @@ def main():
model.eval()
samples_seen = 0
for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
with torch.no_grad():
outputs = model(**batch)
upsampled_logits = torch.nn.functional.interpolate(
......
@@ -22,6 +22,7 @@ import random
from pathlib import Path
import datasets
import torch
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
@@ -514,6 +515,7 @@ def main():
model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
predictions, references = accelerator.gather((predictions, batch["labels"]))
......
@@ -28,6 +28,7 @@ from dataclasses import dataclass, field
from typing import Optional, List
import datasets
import torch
from datasets import load_dataset
import transformers
@@ -871,6 +872,7 @@ def main():
model.eval()
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
metric.add_batch(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment