"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3f30ebe6ca27b2cbad88c890ad5183b54f19db3c"
Unverified Commit 6fe8d198 authored by Wang, Yi's avatar Wang, Yi Committed by GitHub
Browse files

use accelerate autocast in jit eval path, since mix precision logic is… (#24460)



use accelerate autocast in jit eval path, since mix precision logic is in accelerator currently
Signed-off-by: default avatarWang, Yi A <yi.a.wang@intel.com>
parent 0863436b
......@@ -17,6 +17,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
"""
import contextlib
import copy
import functools
import glob
import inspect
......@@ -143,7 +144,6 @@ from .utils import (
logging,
strtobool,
)
from .utils.generic import ContextManagers
DEFAULT_CALLBACKS = [DefaultFlowCallback]
......@@ -1265,9 +1265,14 @@ class Trainer:
example_batch = next(iter(dataloader))
example_batch = self._prepare_inputs(example_batch)
try:
jit_model = model.eval()
with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"):
jit_model = copy.copy(model)
jit_model.eval()
original_forward = jit_model.__dict__.pop("_original_forward", None)
# remove mixed precision hooks from the model
if original_forward:
jit_model.forward = original_forward
with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
if isinstance(example_batch, dict):
jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment