Unverified commit ba3264b4, authored by amyeroberts and committed by GitHub
Browse files

Image Feature Extraction pipeline (#28216)



* Draft pipeline

* Fixup

* Fix docstrings

* Update doctest

* Update pipeline_model_mapping

* Update docstring

* Update tests

* Update src/transformers/pipelines/image_feature_extraction.py
Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>

* Fix docstrings - review comments

* Remove pipeline mapping for composite vision models

* Add to pipeline tests

* Remove for flava (multimodal)

* safe pil import

* Add requirements for pipeline run

* Account for super slow efficientnet

* Review comments

* Fix tests

* Swap order of kwargs

* Use build_pipeline_init_args

* Add back FE pipeline for Vilt

* Include image_processor_kwargs in docstring

* Mark test as flaky

* Update TODO

* Update tests/pipelines/test_pipelines_image_feature_extraction.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add license header

---------
Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 7addc934
...@@ -182,7 +182,7 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin ...@@ -182,7 +182,7 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
) )
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": DetrModel, "image-feature-extraction": DetrModel,
"image-segmentation": DetrForSegmentation, "image-segmentation": DetrForSegmentation,
"object-detection": DetrForObjectDetection, "object-detection": DetrForObjectDetection,
} }
......
...@@ -207,7 +207,7 @@ class DinatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -207,7 +207,7 @@ class DinatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else () else ()
) )
pipeline_model_mapping = ( pipeline_model_mapping = (
{"feature-extraction": DinatModel, "image-classification": DinatForImageClassification} {"image-feature-extraction": DinatModel, "image-classification": DinatForImageClassification}
if is_torch_available() if is_torch_available()
else {} else {}
) )
......
...@@ -217,7 +217,7 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -217,7 +217,7 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else () else ()
) )
pipeline_model_mapping = ( pipeline_model_mapping = (
{"feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification} {"image-feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
if is_torch_available() if is_torch_available()
else {} else {}
) )
......
...@@ -145,7 +145,7 @@ class DonutSwinModelTester: ...@@ -145,7 +145,7 @@ class DonutSwinModelTester:
@require_torch @require_torch
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DonutSwinModel,) if is_torch_available() else () all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": DonutSwinModel} if is_torch_available() else {} pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
fx_compatible = True fx_compatible = True
test_pruning = False test_pruning = False
......
...@@ -163,7 +163,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -163,7 +163,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"depth-estimation": DPTForDepthEstimation, "depth-estimation": DPTForDepthEstimation,
"feature-extraction": DPTModel, "image-feature-extraction": DPTModel,
"image-segmentation": DPTForSemanticSegmentation, "image-segmentation": DPTForSemanticSegmentation,
} }
if is_torch_available() if is_torch_available()
......
...@@ -190,7 +190,7 @@ class EfficientFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T ...@@ -190,7 +190,7 @@ class EfficientFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
) )
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": EfficientFormerModel, "image-feature-extraction": EfficientFormerModel,
"image-classification": ( "image-classification": (
EfficientFormerForImageClassification, EfficientFormerForImageClassification,
EfficientFormerForImageClassificationWithTeacher, EfficientFormerForImageClassificationWithTeacher,
......
...@@ -130,7 +130,7 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test ...@@ -130,7 +130,7 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
all_model_classes = (EfficientNetModel, EfficientNetForImageClassification) if is_torch_available() else () all_model_classes = (EfficientNetModel, EfficientNetForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{"feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification} {"image-feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification}
if is_torch_available() if is_torch_available()
else {} else {}
) )
...@@ -216,6 +216,12 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test ...@@ -216,6 +216,12 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
model = EfficientNetModel.from_pretrained(model_name) model = EfficientNetModel.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
@is_pipeline_test
@require_vision
@slow
def test_pipeline_image_feature_extraction(self):
super().test_pipeline_image_feature_extraction()
@is_pipeline_test @is_pipeline_test
@require_vision @require_vision
@slow @slow
......
...@@ -238,7 +238,7 @@ class FocalNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase ...@@ -238,7 +238,7 @@ class FocalNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
else () else ()
) )
pipeline_model_mapping = ( pipeline_model_mapping = (
{"feature-extraction": FocalNetModel, "image-classification": FocalNetForImageClassification} {"image-feature-extraction": FocalNetModel, "image-classification": FocalNetForImageClassification}
if is_torch_available() if is_torch_available()
else {} else {}
) )
......
...@@ -146,7 +146,9 @@ class GLPNModelTester: ...@@ -146,7 +146,9 @@ class GLPNModelTester:
class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GLPNModel, GLPNForDepthEstimation) if is_torch_available() else () all_model_classes = (GLPNModel, GLPNForDepthEstimation) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{"depth-estimation": GLPNForDepthEstimation, "feature-extraction": GLPNModel} if is_torch_available() else {} {"depth-estimation": GLPNForDepthEstimation, "image-feature-extraction": GLPNModel}
if is_torch_available()
else {}
) )
test_head_masking = False test_head_masking = False
......
...@@ -271,7 +271,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM ...@@ -271,7 +271,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
) )
all_generative_model_classes = (ImageGPTForCausalImageModeling,) if is_torch_available() else () all_generative_model_classes = (ImageGPTForCausalImageModeling,) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{"feature-extraction": ImageGPTModel, "image-classification": ImageGPTForImageClassification} {"image-feature-extraction": ImageGPTModel, "image-classification": ImageGPTForImageClassification}
if is_torch_available() if is_torch_available()
else {} else {}
) )
......
...@@ -176,7 +176,7 @@ class LevitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -176,7 +176,7 @@ class LevitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
) )
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": LevitModel, "image-feature-extraction": LevitModel,
"image-classification": (LevitForImageClassification, LevitForImageClassificationWithTeacher), "image-classification": (LevitForImageClassification, LevitForImageClassificationWithTeacher),
} }
if is_torch_available() if is_torch_available()
......
...@@ -197,7 +197,7 @@ class Mask2FormerModelTester: ...@@ -197,7 +197,7 @@ class Mask2FormerModelTester:
@require_torch @require_torch
class Mask2FormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): class Mask2FormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Mask2FormerModel, Mask2FormerForUniversalSegmentation) if is_torch_available() else () all_model_classes = (Mask2FormerModel, Mask2FormerForUniversalSegmentation) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": Mask2FormerModel} if is_torch_available() else {} pipeline_model_mapping = {"image-feature-extraction": Mask2FormerModel} if is_torch_available() else {}
is_encoder_decoder = False is_encoder_decoder = False
test_pruning = False test_pruning = False
......
...@@ -197,7 +197,7 @@ class MaskFormerModelTester: ...@@ -197,7 +197,7 @@ class MaskFormerModelTester:
class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else () all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{"feature-extraction": MaskFormerModel, "image-segmentation": MaskFormerForInstanceSegmentation} {"image-feature-extraction": MaskFormerModel, "image-segmentation": MaskFormerForInstanceSegmentation}
if is_torch_available() if is_torch_available()
else {} else {}
) )
......
...@@ -31,7 +31,7 @@ if is_torch_available(): ...@@ -31,7 +31,7 @@ if is_torch_available():
import torch import torch
from torch import nn from torch import nn
from transformers import MgpstrForSceneTextRecognition from transformers import MgpstrForSceneTextRecognition, MgpstrModel
if is_vision_available(): if is_vision_available():
...@@ -118,7 +118,11 @@ class MgpstrModelTester: ...@@ -118,7 +118,11 @@ class MgpstrModelTester:
@require_torch @require_torch
class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else () all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": MgpstrForSceneTextRecognition} if is_torch_available() else {} pipeline_model_mapping = (
{"feature-extraction": MgpstrForSceneTextRecognition, "image-feature-extraction": MgpstrModel}
if is_torch_available()
else {}
)
fx_compatible = False fx_compatible = False
test_pruning = False test_pruning = False
......
...@@ -147,7 +147,7 @@ class MobileNetV1ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC ...@@ -147,7 +147,7 @@ class MobileNetV1ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
all_model_classes = (MobileNetV1Model, MobileNetV1ForImageClassification) if is_torch_available() else () all_model_classes = (MobileNetV1Model, MobileNetV1ForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{"feature-extraction": MobileNetV1Model, "image-classification": MobileNetV1ForImageClassification} {"image-feature-extraction": MobileNetV1Model, "image-classification": MobileNetV1ForImageClassification}
if is_torch_available() if is_torch_available()
else {} else {}
) )
......
...@@ -195,7 +195,7 @@ class MobileNetV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC ...@@ -195,7 +195,7 @@ class MobileNetV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
) )
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": MobileNetV2Model, "image-feature-extraction": MobileNetV2Model,
"image-classification": MobileNetV2ForImageClassification, "image-classification": MobileNetV2ForImageClassification,
"image-segmentation": MobileNetV2ForSemanticSegmentation, "image-segmentation": MobileNetV2ForSemanticSegmentation,
} }
......
...@@ -188,7 +188,7 @@ class MobileViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas ...@@ -188,7 +188,7 @@ class MobileViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
) )
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": MobileViTModel, "image-feature-extraction": MobileViTModel,
"image-classification": MobileViTForImageClassification, "image-classification": MobileViTForImageClassification,
"image-segmentation": MobileViTForSemanticSegmentation, "image-segmentation": MobileViTForSemanticSegmentation,
} }
......
...@@ -190,7 +190,7 @@ class MobileViTV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC ...@@ -190,7 +190,7 @@ class MobileViTV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": MobileViTV2Model, "image-feature-extraction": MobileViTV2Model,
"image-classification": MobileViTV2ForImageClassification, "image-classification": MobileViTV2ForImageClassification,
"image-segmentation": MobileViTV2ForSemanticSegmentation, "image-segmentation": MobileViTV2ForSemanticSegmentation,
} }
......
...@@ -204,7 +204,7 @@ class NatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -204,7 +204,7 @@ class NatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else () else ()
) )
pipeline_model_mapping = ( pipeline_model_mapping = (
{"feature-extraction": NatModel, "image-classification": NatForImageClassification} {"image-feature-extraction": NatModel, "image-classification": NatForImageClassification}
if is_torch_available() if is_torch_available()
else {} else {}
) )
......
...@@ -433,7 +433,10 @@ class Owlv2ModelTester: ...@@ -433,7 +433,10 @@ class Owlv2ModelTester:
class Owlv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): class Owlv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Owlv2Model,) if is_torch_available() else () all_model_classes = (Owlv2Model,) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{"feature-extraction": Owlv2Model, "zero-shot-object-detection": Owlv2ForObjectDetection} {
"feature-extraction": Owlv2Model,
"zero-shot-object-detection": Owlv2ForObjectDetection,
}
if is_torch_available() if is_torch_available()
else {} else {}
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment