"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "0412f3d9298cdb8ba7f69570753ec6a07d240c87"
Unverified Commit 8ddbfe97 authored by Kamal Raj's avatar Kamal Raj Committed by GitHub
Browse files

switch to inference_mode from no_gard (#13667)

* switch to inference_mode from no_gard
faster inference

* added switch to support older version of pytorch
parent ebd48c6d
......@@ -25,6 +25,8 @@ from contextlib import contextmanager
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from packaging import version
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
......@@ -866,7 +868,12 @@ class Pipeline(_ScikitCompat):
model_inputs["training"] = False
model_outputs = self._forward(model_inputs, **forward_params)
elif self.framework == "pt":
with torch.no_grad():
inference_context = (
torch.inference_mode
if version.parse(torch.__version__) >= version.parse("1.9.0")
else torch.no_grad
)
with inference_context():
model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
model_outputs = self._forward(model_inputs, **forward_params)
model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment