Unverified Commit 6fe8d198 authored by Wang, Yi's avatar Wang, Yi Committed by GitHub
Browse files

use accelerate autocast in jit eval path, since mix precision logic is… (#24460)



use accelerate autocast in jit eval path, since mix precision logic is in accelerator currently
Signed-off-by: default avatarWang, Yi A <yi.a.wang@intel.com>
parent 0863436b
...@@ -17,6 +17,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune ...@@ -17,6 +17,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
""" """
import contextlib import contextlib
import copy
import functools import functools
import glob import glob
import inspect import inspect
...@@ -143,7 +144,6 @@ from .utils import ( ...@@ -143,7 +144,6 @@ from .utils import (
logging, logging,
strtobool, strtobool,
) )
from .utils.generic import ContextManagers
DEFAULT_CALLBACKS = [DefaultFlowCallback] DEFAULT_CALLBACKS = [DefaultFlowCallback]
...@@ -1265,9 +1265,14 @@ class Trainer: ...@@ -1265,9 +1265,14 @@ class Trainer:
example_batch = next(iter(dataloader)) example_batch = next(iter(dataloader))
example_batch = self._prepare_inputs(example_batch) example_batch = self._prepare_inputs(example_batch)
try: try:
jit_model = model.eval() jit_model = copy.copy(model)
with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]): jit_model.eval()
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"): original_forward = jit_model.__dict__.pop("_original_forward", None)
# remove mixed precision hooks from the model
if original_forward:
jit_model.forward = original_forward
with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
if isinstance(example_batch, dict): if isinstance(example_batch, dict):
jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False) jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment