Unverified Commit b3385963 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing image segmentation with inference mode. (#14204)



* Fixing image segmentation for inference mode.

* Update src/transformers/pipelines/base.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent c28bc80b
...@@ -1011,17 +1011,19 @@ class Pipeline(_ScikitCompat): ...@@ -1011,17 +1011,19 @@ class Pipeline(_ScikitCompat):
""" """
raise NotImplementedError("postprocess not implemented") raise NotImplementedError("postprocess not implemented")
def get_inference_context(self):
inference_context = (
torch.inference_mode if version.parse(torch.__version__) >= version.parse("1.9.0") else torch.no_grad
)
return inference_context
def forward(self, model_inputs, **forward_params): def forward(self, model_inputs, **forward_params):
with self.device_placement(): with self.device_placement():
if self.framework == "tf": if self.framework == "tf":
model_inputs["training"] = False model_inputs["training"] = False
model_outputs = self._forward(model_inputs, **forward_params) model_outputs = self._forward(model_inputs, **forward_params)
elif self.framework == "pt": elif self.framework == "pt":
inference_context = ( inference_context = self.get_inference_context()
torch.inference_mode
if version.parse(torch.__version__) >= version.parse("1.9.0")
else torch.no_grad
)
with inference_context(): with inference_context():
model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device) model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
model_outputs = self._forward(model_inputs, **forward_params) model_outputs = self._forward(model_inputs, **forward_params)
......
...@@ -114,6 +114,9 @@ class ImageSegmentationPipeline(Pipeline): ...@@ -114,6 +114,9 @@ class ImageSegmentationPipeline(Pipeline):
return super().__call__(*args, **kwargs) return super().__call__(*args, **kwargs)
def get_inference_context(self):
return torch.no_grad
def preprocess(self, image): def preprocess(self, image):
image = self.load_image(image) image = self.load_image(image)
target_size = torch.IntTensor([[image.height, image.width]]) target_size = torch.IntTensor([[image.height, image.width]])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment