Unverified Commit ac262604 authored by Billy Cao, committed by GitHub

Allow FP16 or other precision inference for Pipelines (#31342)



* cast image features to model.dtype where needed to support FP16 or other precision in pipelines

* Update src/transformers/pipelines/image_feature_extraction.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Use .to instead

* Add FP16 pipeline support for zeroshot audio classification

* Remove unused torch imports

* Add docs on FP16 pipeline

* Remove unused import

* Add FP16 tests to pipeline mixin

* Add fp16 placeholder for mask_generation pipeline test

* Add FP16 tests for all pipelines

* Fix formatting

* Remove torch_dtype arg from is_pipeline_test_to_skip*

* Fix format

* trigger ci

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent e7868444
@@ -270,6 +270,11 @@
This is a simplified view, since the pipeline can handle the batching automatically. This means you don't need to care
about how many forward passes your inputs are actually going to trigger; you can optimize the `batch_size`
independently of the inputs. The caveats from the previous section still apply.
## Pipeline FP16 inference

Models can be run in FP16, which can be significantly faster on GPU while saving memory. Most models will not suffer a noticeable performance loss from this, and the larger the model, the less likely it is to.

To enable FP16 inference, simply pass `torch_dtype=torch.float16` or `torch_dtype='float16'` to the pipeline constructor. Note that this only works for models with a PyTorch backend. Your inputs will be converted to FP16 internally.
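For example (a minimal sketch; the checkpoint and image URL are illustrative, and any PyTorch-backed pipeline works the same way):

```python
import torch
from transformers import pipeline

# torch_dtype=torch.float16 loads the weights in half precision; the pipeline
# casts the preprocessed inputs to the same dtype internally.
classifier = pipeline(
    "image-classification",
    model="google/vit-base-patch16-224",
    device=0,  # FP16 inference assumes a CUDA device
    torch_dtype=torch.float16,
)
predictions = classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
```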
## Pipeline custom code

If you want to override a specific pipeline.
......
@@ -113,7 +113,9 @@
This will work regardless of whether you are using PyTorch or TensorFlow.

transcriber = pipeline(model="openai/whisper-large-v2", device=0)
```
If the model is too large for a single GPU and you are using PyTorch, you can set `torch_dtype='float16'` to enable FP16 inference. This usually does not cause significant performance drops, but make sure to evaluate it on your models!

Alternatively, you can set `device_map="auto"` to automatically
determine how to load and store the model weights. Using the `device_map` argument requires the 🤗 [Accelerate](https://huggingface.co/docs/accelerate)
package:
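A minimal sketch of combining both options once Accelerate is installed (`pip install accelerate`); the checkpoint is the Whisper model used above:

```python
import torch
from transformers import pipeline

# device_map="auto" lets Accelerate decide where to place the weights;
# torch_dtype=torch.float16 halves the memory footprint of the model.
transcriber = pipeline(
    model="openai/whisper-large-v2",
    device_map="auto",
    torch_dtype=torch.float16,
)
```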
@@ -342,4 +344,3 @@
gr.Interface.from_pipeline(pipe).launch()

By default, the web demo runs on a local server. If you'd like to share it with others, you can generate a temporary public
link by setting `share=True` in `launch()`. You can also host your demo on [Hugging Face Spaces](https://huggingface.co/spaces) for a permanent link.
@@ -91,6 +91,8 @@ class DepthEstimationPipeline(Pipeline):
        image = load_image(image, timeout)
        self.image_size = image.size
        model_inputs = self.image_processor(images=image, return_tensors=self.framework)
        if self.framework == "pt":
            model_inputs = model_inputs.to(self.torch_dtype)
        return model_inputs

    def _forward(self, model_inputs):
......
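Every `preprocess` change in this commit follows the pattern above: the processor returns a `BatchFeature`, and its `.to()` method casts only floating-point tensors, so integer entries such as `input_ids` keep their dtype. A small standalone sketch of that behavior (the tensors are illustrative):

```python
import torch
from transformers import BatchFeature

features = BatchFeature(
    {
        "pixel_values": torch.zeros(1, 3, 224, 224, dtype=torch.float32),
        "input_ids": torch.tensor([[101, 102]]),  # integer tensor
    }
)
features = features.to(torch.float16)
print(features["pixel_values"].dtype)  # torch.float16 -- floats are cast
print(features["input_ids"].dtype)     # torch.int64 -- integers are untouched
```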
@@ -294,7 +294,10 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
        if input.get("image", None) is not None:
            image = load_image(input["image"], timeout=timeout)
            if self.image_processor is not None:
                image_inputs = self.image_processor(images=image, return_tensors=self.framework)
                if self.framework == "pt":
                    image_inputs = image_inputs.to(self.torch_dtype)
                image_features.update(image_inputs)
            elif self.feature_extractor is not None:
                image_features.update(self.feature_extractor(images=image, return_tensors=self.framework))
        elif self.model_type == ModelType.VisionEncoderDecoder:
......
@@ -161,6 +161,8 @@ class ImageClassificationPipeline(Pipeline):
    def preprocess(self, image, timeout=None):
        image = load_image(image, timeout=timeout)
        model_inputs = self.image_processor(images=image, return_tensors=self.framework)
        if self.framework == "pt":
            model_inputs = model_inputs.to(self.torch_dtype)
        return model_inputs

    def _forward(self, model_inputs):
......
@@ -60,6 +60,8 @@ class ImageFeatureExtractionPipeline(Pipeline):
    def preprocess(self, image, timeout=None, **image_processor_kwargs) -> Dict[str, GenericTensor]:
        image = load_image(image, timeout=timeout)
        model_inputs = self.image_processor(image, return_tensors=self.framework, **image_processor_kwargs)
        if self.framework == "pt":
            model_inputs = model_inputs.to(self.torch_dtype)
        return model_inputs

    def _forward(self, model_inputs):
......
@@ -147,6 +147,8 @@ class ImageSegmentationPipeline(Pipeline):
            else:
                kwargs = {"task_inputs": [subtask]}
            inputs = self.image_processor(images=[image], return_tensors="pt", **kwargs)
            if self.framework == "pt":
                inputs = inputs.to(self.torch_dtype)
            inputs["task_inputs"] = self.tokenizer(
                inputs["task_inputs"],
                padding="max_length",
@@ -155,6 +157,8 @@ class ImageSegmentationPipeline(Pipeline):
            )["input_ids"]
        else:
            inputs = self.image_processor(images=[image], return_tensors="pt")
            if self.framework == "pt":
                inputs = inputs.to(self.torch_dtype)
        inputs["target_size"] = target_size
        return inputs
......
@@ -119,6 +119,8 @@ class ImageToImagePipeline(Pipeline):
    def preprocess(self, image, timeout=None):
        image = load_image(image, timeout=timeout)
        inputs = self.image_processor(images=[image], return_tensors="pt")
        if self.framework == "pt":
            inputs = inputs.to(self.torch_dtype)
        return inputs

    def postprocess(self, model_outputs):
......
@@ -138,6 +138,8 @@ class ImageToTextPipeline(Pipeline):
        if model_type == "git":
            model_inputs = self.image_processor(images=image, return_tensors=self.framework)
            if self.framework == "pt":
                model_inputs = model_inputs.to(self.torch_dtype)
            input_ids = self.tokenizer(text=prompt, add_special_tokens=False).input_ids
            input_ids = [self.tokenizer.cls_token_id] + input_ids
            input_ids = torch.tensor(input_ids).unsqueeze(0)
@@ -145,10 +147,14 @@ class ImageToTextPipeline(Pipeline):
        elif model_type == "pix2struct":
            model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors=self.framework)
            if self.framework == "pt":
                model_inputs = model_inputs.to(self.torch_dtype)
        elif model_type != "vision-encoder-decoder":
            # vision-encoder-decoder does not support conditional generation
            model_inputs = self.image_processor(images=image, return_tensors=self.framework)
            if self.framework == "pt":
                model_inputs = model_inputs.to(self.torch_dtype)
            text_inputs = self.tokenizer(prompt, return_tensors=self.framework)
            model_inputs.update(text_inputs)
@@ -157,6 +163,8 @@ class ImageToTextPipeline(Pipeline):
        else:
            model_inputs = self.image_processor(images=image, return_tensors=self.framework)
            if self.framework == "pt":
                model_inputs = model_inputs.to(self.torch_dtype)

        if self.model.config.model_type == "git" and prompt is None:
            model_inputs["input_ids"] = None
......
@@ -181,6 +181,8 @@ class MaskGenerationPipeline(ChunkPipeline):
            image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
        )
        model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")
        if self.framework == "pt":
            model_inputs = model_inputs.to(self.torch_dtype)

        with self.device_placement():
            if self.framework == "pt":
......
@@ -107,6 +107,8 @@ class ObjectDetectionPipeline(Pipeline):
        image = load_image(image, timeout=timeout)
        target_size = torch.IntTensor([[image.height, image.width]])
        inputs = self.image_processor(images=[image], return_tensors="pt")
        if self.framework == "pt":
            inputs = inputs.to(self.torch_dtype)
        if self.tokenizer is not None:
            inputs = self.tokenizer(text=inputs["words"], boxes=inputs["boxes"], return_tensors="pt")
        inputs["target_size"] = target_size
......
@@ -106,6 +106,8 @@ class VideoClassificationPipeline(Pipeline):
        video = list(video)

        model_inputs = self.image_processor(video, return_tensors=self.framework)
        if self.framework == "pt":
            model_inputs = model_inputs.to(self.torch_dtype)
        return model_inputs

    def _forward(self, model_inputs):
......
@@ -155,6 +155,8 @@ class VisualQuestionAnsweringPipeline(Pipeline):
            truncation=truncation,
        )
        image_features = self.image_processor(images=image, return_tensors=self.framework)
        if self.framework == "pt":
            image_features = image_features.to(self.torch_dtype)
        model_inputs.update(image_features)
        return model_inputs
......
@@ -121,6 +121,8 @@ class ZeroShotAudioClassificationPipeline(Pipeline):
        inputs = self.feature_extractor(
            [audio], sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
        )
        if self.framework == "pt":
            inputs = inputs.to(self.torch_dtype)
        inputs["candidate_labels"] = candidate_labels
        sequences = [hypothesis_template.format(x) for x in candidate_labels]
        text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True)
......
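The commit message calls out zero-shot audio classification specifically. A usage sketch of the new capability (the CLAP checkpoint and the silent clip are illustrative; FP16 assumes a CUDA device):

```python
import numpy as np
import torch
from transformers import pipeline

classifier = pipeline(
    "zero-shot-audio-classification",
    model="laion/clap-htsat-unfused",
    device=0,
    torch_dtype=torch.float16,
)
# Five seconds of silence at 48 kHz stands in for a real recording.
audio = np.zeros(48_000 * 5, dtype=np.float32)
result = classifier(audio, candidate_labels=["dog barking", "rain", "speech"])
```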
@@ -120,6 +120,8 @@ class ZeroShotImageClassificationPipeline(Pipeline):
    def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}.", timeout=None):
        image = load_image(image, timeout=timeout)
        inputs = self.image_processor(images=[image], return_tensors=self.framework)
        if self.framework == "pt":
            inputs = inputs.to(self.torch_dtype)
        inputs["candidate_labels"] = candidate_labels
        sequences = [hypothesis_template.format(x) for x in candidate_labels]
        padding = "max_length" if self.model.config.model_type == "siglip" else True
......
@@ -156,6 +156,8 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
        for i, candidate_label in enumerate(candidate_labels):
            text_inputs = self.tokenizer(candidate_label, return_tensors=self.framework)
            image_features = self.image_processor(image, return_tensors=self.framework)
            if self.framework == "pt":
                image_features = image_features.to(self.torch_dtype)
            yield {
                "is_last": i == len(candidate_labels) - 1,
                "target_size": target_size,
......
@@ -35,8 +35,10 @@ class AudioClassificationPipelineTests(unittest.TestCase):
    model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
    tf_model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING

    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
        audio_classifier = AudioClassificationPipeline(
            model=model, feature_extractor=processor, torch_dtype=torch_dtype
        )

        # test with a raw waveform
        audio = np.zeros((34000,))
......
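The behavior these updated tests cover can be reproduced directly. A sketch (`superb/wav2vec2-base-superb-ks` is a small public checkpoint standing in for the tiny models the test harness generates; FP16 assumes a CUDA device):

```python
import numpy as np
import torch
from transformers import pipeline

audio_classifier = pipeline(
    "audio-classification",
    model="superb/wav2vec2-base-superb-ks",
    device=0,
    torch_dtype=torch.float16,
)
# Same dummy waveform shape the test uses.
audio = np.zeros((34000,), dtype=np.float32)
outputs = audio_classifier(audio)
```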
@@ -66,14 +66,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
        + (MODEL_FOR_CTC_MAPPING.items() if MODEL_FOR_CTC_MAPPING else [])
    )

    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
        if tokenizer is None:
            # Side effect of no Fast Tokenizer class for these models, so skipping
            # But the slow tokenizer tests should still run as they're quite small
            self.skipTest(reason="No tokenizer available")

        speech_recognizer = AutomaticSpeechRecognitionPipeline(
            model=model, tokenizer=tokenizer, feature_extractor=processor, torch_dtype=torch_dtype
        )

        # test with a raw waveform
......
@@ -56,8 +56,8 @@ def hashimage(image: Image) -> str:
class DepthEstimationPipelineTests(unittest.TestCase):
    model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING

    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
        depth_estimator = DepthEstimationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
        return depth_estimator, [
            "./tests/fixtures/tests_samples/COCO/000000039769.png",
            "./tests/fixtures/tests_samples/COCO/000000039769.png",
......
@@ -61,9 +61,13 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase):
    @require_pytesseract
    @require_vision
    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
        dqa_pipeline = pipeline(
            "document-question-answering",
            model=model,
            tokenizer=tokenizer,
            image_processor=processor,
            torch_dtype=torch_dtype,
        )

        image = INVOICE_URL
......