Unverified Commit ba3264b4 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Image Feature Extraction pipeline (#28216)



* Draft pipeline

* Fixup

* Fix docstrings

* Update doctest

* Update pipeline_model_mapping

* Update docstring

* Update tests

* Update src/transformers/pipelines/image_feature_extraction.py
Co-authored-by: default avatarOmar Sanseviero <osanseviero@gmail.com>

* Fix docstrings - review comments

* Remove pipeline mapping for composite vision models

* Add to pipeline tests

* Remove for flava (multimodal)

* safe pil import

* Add requirements for pipeline run

* Account for super slow efficientnet

* Review comments

* Fix tests

* Swap order of kwargs

* Use build_pipeline_init_args

* Add back FE pipeline for Vilt

* Include image_processor_kwargs in docstring

* Mark test as flaky

* Update TODO

* Update tests/pipelines/test_pipelines_image_feature_extraction.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Add license header

---------
Co-authored-by: default avatarOmar Sanseviero <osanseviero@gmail.com>
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 7addc934
......@@ -182,7 +182,7 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
)
pipeline_model_mapping = (
{
"feature-extraction": DetrModel,
"image-feature-extraction": DetrModel,
"image-segmentation": DetrForSegmentation,
"object-detection": DetrForObjectDetection,
}
......
......@@ -207,7 +207,7 @@ class DinatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": DinatModel, "image-classification": DinatForImageClassification}
{"image-feature-extraction": DinatModel, "image-classification": DinatForImageClassification}
if is_torch_available()
else {}
)
......
......@@ -217,7 +217,7 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
{"image-feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
if is_torch_available()
else {}
)
......
......@@ -145,7 +145,7 @@ class DonutSwinModelTester:
@require_torch
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": DonutSwinModel} if is_torch_available() else {}
pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
fx_compatible = True
test_pruning = False
......
......@@ -163,7 +163,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
pipeline_model_mapping = (
{
"depth-estimation": DPTForDepthEstimation,
"feature-extraction": DPTModel,
"image-feature-extraction": DPTModel,
"image-segmentation": DPTForSemanticSegmentation,
}
if is_torch_available()
......
......@@ -190,7 +190,7 @@ class EfficientFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
)
pipeline_model_mapping = (
{
"feature-extraction": EfficientFormerModel,
"image-feature-extraction": EfficientFormerModel,
"image-classification": (
EfficientFormerForImageClassification,
EfficientFormerForImageClassificationWithTeacher,
......
......@@ -130,7 +130,7 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
all_model_classes = (EfficientNetModel, EfficientNetForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification}
{"image-feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification}
if is_torch_available()
else {}
)
......@@ -216,6 +216,12 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
model = EfficientNetModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@is_pipeline_test
@require_vision
@slow
def test_pipeline_image_feature_extraction(self):
super().test_pipeline_image_feature_extraction()
@is_pipeline_test
@require_vision
@slow
......
......@@ -238,7 +238,7 @@ class FocalNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
else ()
)
pipeline_model_mapping = (
{"feature-extraction": FocalNetModel, "image-classification": FocalNetForImageClassification}
{"image-feature-extraction": FocalNetModel, "image-classification": FocalNetForImageClassification}
if is_torch_available()
else {}
)
......
......@@ -146,7 +146,9 @@ class GLPNModelTester:
class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GLPNModel, GLPNForDepthEstimation) if is_torch_available() else ()
pipeline_model_mapping = (
{"depth-estimation": GLPNForDepthEstimation, "feature-extraction": GLPNModel} if is_torch_available() else {}
{"depth-estimation": GLPNForDepthEstimation, "image-feature-extraction": GLPNModel}
if is_torch_available()
else {}
)
test_head_masking = False
......
......@@ -271,7 +271,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
)
all_generative_model_classes = (ImageGPTForCausalImageModeling,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": ImageGPTModel, "image-classification": ImageGPTForImageClassification}
{"image-feature-extraction": ImageGPTModel, "image-classification": ImageGPTForImageClassification}
if is_torch_available()
else {}
)
......
......@@ -176,7 +176,7 @@ class LevitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
)
pipeline_model_mapping = (
{
"feature-extraction": LevitModel,
"image-feature-extraction": LevitModel,
"image-classification": (LevitForImageClassification, LevitForImageClassificationWithTeacher),
}
if is_torch_available()
......
......@@ -197,7 +197,7 @@ class Mask2FormerModelTester:
@require_torch
class Mask2FormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Mask2FormerModel, Mask2FormerForUniversalSegmentation) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": Mask2FormerModel} if is_torch_available() else {}
pipeline_model_mapping = {"image-feature-extraction": Mask2FormerModel} if is_torch_available() else {}
is_encoder_decoder = False
test_pruning = False
......
......@@ -197,7 +197,7 @@ class MaskFormerModelTester:
class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": MaskFormerModel, "image-segmentation": MaskFormerForInstanceSegmentation}
{"image-feature-extraction": MaskFormerModel, "image-segmentation": MaskFormerForInstanceSegmentation}
if is_torch_available()
else {}
)
......
......@@ -31,7 +31,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import MgpstrForSceneTextRecognition
from transformers import MgpstrForSceneTextRecognition, MgpstrModel
if is_vision_available():
......@@ -118,7 +118,11 @@ class MgpstrModelTester:
@require_torch
class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": MgpstrForSceneTextRecognition} if is_torch_available() else {}
pipeline_model_mapping = (
{"feature-extraction": MgpstrForSceneTextRecognition, "image-feature-extraction": MgpstrModel}
if is_torch_available()
else {}
)
fx_compatible = False
test_pruning = False
......
......@@ -147,7 +147,7 @@ class MobileNetV1ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
all_model_classes = (MobileNetV1Model, MobileNetV1ForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": MobileNetV1Model, "image-classification": MobileNetV1ForImageClassification}
{"image-feature-extraction": MobileNetV1Model, "image-classification": MobileNetV1ForImageClassification}
if is_torch_available()
else {}
)
......
......@@ -195,7 +195,7 @@ class MobileNetV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
)
pipeline_model_mapping = (
{
"feature-extraction": MobileNetV2Model,
"image-feature-extraction": MobileNetV2Model,
"image-classification": MobileNetV2ForImageClassification,
"image-segmentation": MobileNetV2ForSemanticSegmentation,
}
......
......@@ -188,7 +188,7 @@ class MobileViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
)
pipeline_model_mapping = (
{
"feature-extraction": MobileViTModel,
"image-feature-extraction": MobileViTModel,
"image-classification": MobileViTForImageClassification,
"image-segmentation": MobileViTForSemanticSegmentation,
}
......
......@@ -190,7 +190,7 @@ class MobileViTV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
pipeline_model_mapping = (
{
"feature-extraction": MobileViTV2Model,
"image-feature-extraction": MobileViTV2Model,
"image-classification": MobileViTV2ForImageClassification,
"image-segmentation": MobileViTV2ForSemanticSegmentation,
}
......
......@@ -204,7 +204,7 @@ class NatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": NatModel, "image-classification": NatForImageClassification}
{"image-feature-extraction": NatModel, "image-classification": NatForImageClassification}
if is_torch_available()
else {}
)
......
......@@ -433,7 +433,10 @@ class Owlv2ModelTester:
class Owlv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Owlv2Model,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": Owlv2Model, "zero-shot-object-detection": Owlv2ForObjectDetection}
{
"feature-extraction": Owlv2Model,
"zero-shot-object-detection": Owlv2ForObjectDetection,
}
if is_torch_available()
else {}
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment