Unverified Commit de6d19ea authored by amyeroberts, committed by GitHub

Add segmentation + object detection image processors (#20160)

* Add transforms for object detection

* DETR models + Yolos

* Scrappy additions

* Maskformer image processor

* Fix up; MaskFormer tests

* Update owlvit processor

* Add to docs

* OwlViT tests

* Update pad logic

* Remove changes to transforms

* Import fn directly

* Update to include pad transformation

* Remove unintended changes

* Add new owlvit post processing function

* Tidy up

* Fix copies

* Fix some copies

* Include device fix

* Fix scipy imports

* Update _pad_image

* Update padding functionality

* Fix bug

* Properly handle ignore index

* Fix up

* Remove defaults to None in docstrings

* Fix docstrings & docs

* Fix sizes bug

* Resolve conflicts in init

* Cast to float after resizing

* Tidy & add size if missing

* Allow kwargs when processing for owlvit

* Update test values
parent ae3cbc95
@@ -32,6 +32,16 @@ This model was contributed by [DepuMeng](https://huggingface.co/DepuMeng). The o
 [[autodoc]] ConditionalDetrConfig
+## ConditionalDetrImageProcessor
+[[autodoc]] ConditionalDetrImageProcessor
+- preprocess
+- pad_and_create_pixel_mask
+- post_process_object_detection
+- post_process_instance_segmentation
+- post_process_semantic_segmentation
+- post_process_panoptic_segmentation
 ## ConditionalDetrFeatureExtractor
 [[autodoc]] ConditionalDetrFeatureExtractor
......
@@ -33,6 +33,13 @@ alt="drawing" width="600"/>
 This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/fundamentalvision/Deformable-DETR).
+## DeformableDetrImageProcessor
+[[autodoc]] DeformableDetrImageProcessor
+- preprocess
+- pad_and_create_pixel_mask
+- post_process_object_detection
 ## DeformableDetrFeatureExtractor
 [[autodoc]] DeformableDetrFeatureExtractor
......
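As a rough illustration of the padding helper listed in the docs above, here is a minimal sketch. The dummy input shapes, the channels-first layout, and the expected output shapes are assumptions for illustration only, not taken from this diff:

```python
import numpy as np

from transformers import DeformableDetrImageProcessor

image_processor = DeformableDetrImageProcessor()

# Two images of different sizes; the helper pads both to the largest height/width in the
# batch and returns a pixel_mask marking real vs. padded pixels.
images = [
    np.zeros((3, 480, 640), dtype=np.float32),
    np.zeros((3, 512, 512), dtype=np.float32),
]
encoded = image_processor.pad_and_create_pixel_mask(images, return_tensors="np")
print(encoded["pixel_values"].shape)  # expected (2, 3, 512, 640)
print(encoded["pixel_mask"].shape)    # expected (2, 512, 640)
```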
@@ -166,6 +166,15 @@ mean Average Precision (mAP) and Panoptic Quality (PQ). The latter objects are i
 [[autodoc]] DetrConfig
+## DetrImageProcessor
+[[autodoc]] DetrImageProcessor
+- preprocess
+- post_process_object_detection
+- post_process_semantic_segmentation
+- post_process_instance_segmentation
+- post_process_panoptic_segmentation
 ## DetrFeatureExtractor
 [[autodoc]] DetrFeatureExtractor
......
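A minimal sketch of the detection workflow documented above, using the new `DetrImageProcessor`. The checkpoint name and threshold are illustrative choices, not part of this diff; the same pattern applies to `ConditionalDetrImageProcessor` and `DeformableDetrImageProcessor`:

```python
import requests
import torch
from PIL import Image

from transformers import DetrForObjectDetection, DetrImageProcessor

# Example image from the COCO val set (the same URL used elsewhere in the repo docs)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

# `preprocess` is called under the hood when the processor is used as a callable
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Rescale predicted boxes back to the original image size
target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
results = image_processor.post_process_object_detection(
    outputs, threshold=0.9, target_sizes=target_sizes
)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(model.config.id2label[label.item()], round(score.item(), 3), box.tolist())
```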
@@ -57,6 +57,15 @@ This model was contributed by [francesco](https://huggingface.co/francesco). The
 [[autodoc]] MaskFormerConfig
+## MaskFormerImageProcessor
+[[autodoc]] MaskFormerImageProcessor
+- preprocess
+- encode_inputs
+- post_process_semantic_segmentation
+- post_process_instance_segmentation
+- post_process_panoptic_segmentation
 ## MaskFormerFeatureExtractor
 [[autodoc]] MaskFormerFeatureExtractor
......
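A short, hedged sketch of the semantic post-processing path listed above. The ADE20k checkpoint name is an assumption; any MaskFormer checkpoint should follow the same pattern:

```python
import requests
import torch
from PIL import Image

from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = MaskFormerImageProcessor.from_pretrained("facebook/maskformer-swin-base-ade")
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One (height, width) label map per input image, resized back to the original size
semantic_map = image_processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
print(semantic_map.shape)  # expected torch.Size([height, width])
```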
@@ -76,6 +76,13 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi
 [[autodoc]] OwlViTVisionConfig
+## OwlViTImageProcessor
+[[autodoc]] OwlViTImageProcessor
+- preprocess
+- post_process
+- post_process_image_guided_detection
 ## OwlViTFeatureExtractor
 [[autodoc]] OwlViTFeatureExtractor
......
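For OwlViT, post-processing converts the per-query logits and boxes back to image coordinates. A rough sketch; the checkpoint name and text queries are chosen only for illustration:

```python
import requests
import torch
from PIL import Image

from transformers import OwlViTForObjectDetection, OwlViTImageProcessor, OwlViTProcessor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = [["a photo of a cat", "a photo of a remote control"]]

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
image_processor = OwlViTImageProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

inputs = processor(text=texts, images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# `post_process` turns the raw outputs into per-image detections in pixel coordinates
target_sizes = torch.tensor([image.size[::-1]])
results = image_processor.post_process(outputs=outputs, target_sizes=target_sizes)[0]
print(results["scores"].shape, results["boxes"].shape)
```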
@@ -37,6 +37,12 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
 [[autodoc]] YolosConfig
+## YolosImageProcessor
+[[autodoc]] YolosImageProcessor
+- preprocess
+- pad
+- post_process_object_detection
 ## YolosFeatureExtractor
......
@@ -736,11 +736,15 @@ else:
     _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
     _import_structure["models.beit"].extend(["BeitFeatureExtractor", "BeitImageProcessor"])
     _import_structure["models.clip"].extend(["CLIPFeatureExtractor", "CLIPImageProcessor"])
-    _import_structure["models.conditional_detr"].append("ConditionalDetrFeatureExtractor")
+    _import_structure["models.conditional_detr"].extend(
+        ["ConditionalDetrFeatureExtractor", "ConditionalDetrImageProcessor"]
+    )
     _import_structure["models.convnext"].extend(["ConvNextFeatureExtractor", "ConvNextImageProcessor"])
-    _import_structure["models.deformable_detr"].append("DeformableDetrFeatureExtractor")
+    _import_structure["models.deformable_detr"].extend(
+        ["DeformableDetrFeatureExtractor", "DeformableDetrImageProcessor"]
+    )
     _import_structure["models.deit"].extend(["DeiTFeatureExtractor", "DeiTImageProcessor"])
-    _import_structure["models.detr"].append("DetrFeatureExtractor")
+    _import_structure["models.detr"].extend(["DetrFeatureExtractor", "DetrImageProcessor"])
     _import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"])
     _import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"])
     _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"])
@@ -749,18 +753,18 @@ else:
     _import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"])
     _import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"])
     _import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"])
-    _import_structure["models.maskformer"].append("MaskFormerFeatureExtractor")
+    _import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"])
     _import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"])
     _import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"])
     _import_structure["models.mobilevit"].extend(["MobileViTFeatureExtractor", "MobileViTImageProcessor"])
-    _import_structure["models.owlvit"].append("OwlViTFeatureExtractor")
+    _import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"])
     _import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"])
     _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
     _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
     _import_structure["models.videomae"].extend(["VideoMAEFeatureExtractor", "VideoMAEImageProcessor"])
     _import_structure["models.vilt"].extend(["ViltFeatureExtractor", "ViltImageProcessor", "ViltProcessor"])
     _import_structure["models.vit"].extend(["ViTFeatureExtractor", "ViTImageProcessor"])
-    _import_structure["models.yolos"].extend(["YolosFeatureExtractor"])
+    _import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
 # Timm-backed objects
 try:
@@ -3869,11 +3873,11 @@ if TYPE_CHECKING:
         from .image_utils import ImageFeatureExtractionMixin
         from .models.beit import BeitFeatureExtractor, BeitImageProcessor
         from .models.clip import CLIPFeatureExtractor, CLIPImageProcessor
-        from .models.conditional_detr import ConditionalDetrFeatureExtractor
+        from .models.conditional_detr import ConditionalDetrFeatureExtractor, ConditionalDetrImageProcessor
        from .models.convnext import ConvNextFeatureExtractor, ConvNextImageProcessor
-        from .models.deformable_detr import DeformableDetrFeatureExtractor
+        from .models.deformable_detr import DeformableDetrFeatureExtractor, DeformableDetrImageProcessor
         from .models.deit import DeiTFeatureExtractor, DeiTImageProcessor
-        from .models.detr import DetrFeatureExtractor
+        from .models.detr import DetrFeatureExtractor, DetrImageProcessor
         from .models.donut import DonutFeatureExtractor, DonutImageProcessor
         from .models.dpt import DPTFeatureExtractor, DPTImageProcessor
         from .models.flava import FlavaFeatureExtractor, FlavaImageProcessor, FlavaProcessor
@@ -3882,18 +3886,18 @@ if TYPE_CHECKING:
         from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2ImageProcessor
         from .models.layoutlmv3 import LayoutLMv3FeatureExtractor, LayoutLMv3ImageProcessor
         from .models.levit import LevitFeatureExtractor, LevitImageProcessor
-        from .models.maskformer import MaskFormerFeatureExtractor
+        from .models.maskformer import MaskFormerFeatureExtractor, MaskFormerImageProcessor
         from .models.mobilenet_v1 import MobileNetV1FeatureExtractor, MobileNetV1ImageProcessor
         from .models.mobilenet_v2 import MobileNetV2FeatureExtractor, MobileNetV2ImageProcessor
         from .models.mobilevit import MobileViTFeatureExtractor, MobileViTImageProcessor
-        from .models.owlvit import OwlViTFeatureExtractor
+        from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor
         from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor
         from .models.poolformer import PoolFormerFeatureExtractor, PoolFormerImageProcessor
         from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
         from .models.videomae import VideoMAEFeatureExtractor, VideoMAEImageProcessor
         from .models.vilt import ViltFeatureExtractor, ViltImageProcessor, ViltProcessor
         from .models.vit import ViTFeatureExtractor, ViTImageProcessor
-        from .models.yolos import YolosFeatureExtractor
+        from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
     # Modeling
     try:
......
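The net effect of these `_import_structure` and `TYPE_CHECKING` changes is that the new classes become importable from the package root whenever the vision extras are installed. A quick check (the `model_input_names` value shown in the comment is an assumption about the DETR defaults):

```python
# Requires the vision extras (Pillow), e.g. `pip install "transformers[vision]"`.
from transformers import (
    ConditionalDetrImageProcessor,
    DeformableDetrImageProcessor,
    DetrImageProcessor,
    MaskFormerImageProcessor,
    OwlViTImageProcessor,
    YolosImageProcessor,
)

print(DetrImageProcessor.model_input_names)  # expected ["pixel_values", "pixel_mask"]
```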
@@ -14,7 +14,7 @@
 # limitations under the License.
 import os
-from typing import TYPE_CHECKING, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, Union
 import numpy as np
 from packaging import version
@@ -163,6 +163,47 @@ def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> T
         raise ValueError(f"Unsupported data format: {channel_dim}")
+def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
+    if (
+        isinstance(annotation, dict)
+        and "image_id" in annotation
+        and "annotations" in annotation
+        and isinstance(annotation["annotations"], (list, tuple))
+        and (
+            # an image can have no annotations
+            len(annotation["annotations"]) == 0
+            or isinstance(annotation["annotations"][0], dict)
+        )
+    ):
+        return True
+    return False
+
+
+def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
+    if (
+        isinstance(annotation, dict)
+        and "image_id" in annotation
+        and "segments_info" in annotation
+        and "file_name" in annotation
+        and isinstance(annotation["segments_info"], (list, tuple))
+        and (
+            # an image can have no segments
+            len(annotation["segments_info"]) == 0
+            or isinstance(annotation["segments_info"][0], dict)
+        )
+    ):
+        return True
+    return False
+
+
+def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
+    return all(is_valid_annotation_coco_detection(ann) for ann in annotations)
+
+
+def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
+    return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
 def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
     """
     Loads `image` to a PIL Image.
......
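These helpers only check the overall structure of COCO-style targets; the detection image processors use them to validate `annotations` before preprocessing. A small usage sketch with made-up annotation values:

```python
from transformers.image_utils import (
    is_valid_annotation_coco_detection,
    valid_coco_detection_annotations,
)

# One COCO-detection style target for a single image. The keys inside each annotation
# dict (bbox, category_id, area, ...) are not inspected by the helper, only the
# top-level structure.
annotation = {
    "image_id": 39769,
    "annotations": [
        {"bbox": [13.2, 52.9, 310.1, 410.4], "category_id": 17, "area": 127262.0},
    ],
}

print(is_valid_annotation_coco_detection(annotation))       # True
print(valid_coco_detection_annotations([annotation]))       # True
print(is_valid_annotation_coco_detection({"image_id": 1}))  # False: no "annotations" key
```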
@@ -39,10 +39,14 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
     [
         ("beit", "BeitImageProcessor"),
         ("clip", "CLIPImageProcessor"),
+        ("clipseg", "ViTImageProcessor"),
+        ("conditional_detr", "ConditionalDetrImageProcessor"),
         ("convnext", "ConvNextImageProcessor"),
         ("cvt", "ConvNextImageProcessor"),
         ("data2vec-vision", "BeitImageProcessor"),
+        ("deformable_detr", "DeformableDetrImageProcessor"),
         ("deit", "DeiTImageProcessor"),
+        ("detr", "DetrImageProcessor"),
         ("dinat", "ViTImageProcessor"),
         ("donut-swin", "DonutImageProcessor"),
         ("dpt", "DPTImageProcessor"),
@@ -53,10 +57,14 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
         ("layoutlmv2", "LayoutLMv2ImageProcessor"),
         ("layoutlmv3", "LayoutLMv3ImageProcessor"),
         ("levit", "LevitImageProcessor"),
+        ("maskformer", "MaskFormerImageProcessor"),
         ("mobilenet_v1", "MobileNetV1ImageProcessor"),
         ("mobilenet_v2", "MobileNetV2ImageProcessor"),
+        ("mobilenet_v2", "MobileNetV2ImageProcessor"),
+        ("mobilevit", "MobileViTImageProcessor"),
         ("mobilevit", "MobileViTImageProcessor"),
         ("nat", "ViTImageProcessor"),
+        ("owlvit", "OwlViTImageProcessor"),
         ("perceiver", "PerceiverImageProcessor"),
         ("poolformer", "PoolFormerImageProcessor"),
         ("regnet", "ConvNextImageProcessor"),
@@ -64,6 +72,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
         ("segformer", "SegformerImageProcessor"),
         ("swin", "ViTImageProcessor"),
         ("swinv2", "ViTImageProcessor"),
+        ("table-transformer", "DetrImageProcessor"),
         ("van", "ConvNextImageProcessor"),
         ("videomae", "VideoMAEImageProcessor"),
         ("vilt", "ViltImageProcessor"),
@@ -71,6 +80,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
         ("vit_mae", "ViTImageProcessor"),
         ("vit_msn", "ViTImageProcessor"),
         ("xclip", "CLIPImageProcessor"),
+        ("yolos", "YolosImageProcessor"),
     ]
 )
@@ -113,7 +123,7 @@ def get_image_processor_config(
     **kwargs,
 ):
     """
-    Loads the image processor configuration from a pretrained model imag processor configuration. # FIXME
+    Loads the image processor configuration from a pretrained model image processor configuration.
     Args:
         pretrained_model_name_or_path (`str` or `os.PathLike`):
......
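With these mapping entries, `AutoImageProcessor` can resolve detection and segmentation checkpoints to the new classes. A hedged sketch; the checkpoint name is illustrative and the resolved class ultimately depends on the checkpoint's preprocessor config:

```python
from transformers import AutoImageProcessor

# The "detr" entry in IMAGE_PROCESSOR_MAPPING_NAMES should route DETR checkpoints to
# the new DetrImageProcessor class.
image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
print(type(image_processor).__name__)  # expected "DetrImageProcessor"
```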
@@ -36,6 +36,7 @@ except OptionalDependencyNotAvailable:
     pass
 else:
     _import_structure["feature_extraction_conditional_detr"] = ["ConditionalDetrFeatureExtractor"]
+    _import_structure["image_processing_conditional_detr"] = ["ConditionalDetrImageProcessor"]
 try:
     if not is_timm_available():
@@ -66,6 +67,7 @@ if TYPE_CHECKING:
         pass
     else:
         from .feature_extraction_conditional_detr import ConditionalDetrFeatureExtractor
+        from .image_processing_conditional_detr import ConditionalDetrImageProcessor
     try:
         if not is_timm_available():
......
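For readers unfamiliar with the pattern in these per-model `__init__.py` hunks: the image processor module is only registered when the optional vision dependencies can be imported. A simplified sketch of the guard, assuming the availability checks from `transformers.utils` (the real module additionally wires `_import_structure` into `_LazyModule`):

```python
from transformers.utils import OptionalDependencyNotAvailable, is_vision_available

_import_structure = {"configuration_conditional_detr": ["ConditionalDetrConfig"]}

try:
    if not is_vision_available():  # are Pillow and friends installed?
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass  # vision extras missing: skip registering the processors
else:
    _import_structure["feature_extraction_conditional_detr"] = ["ConditionalDetrFeatureExtractor"]
    _import_structure["image_processing_conditional_detr"] = ["ConditionalDetrImageProcessor"]
```

The same guard is repeated verbatim in the deformable_detr, detr, and maskformer packages below.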
@@ -32,6 +32,7 @@ except OptionalDependencyNotAvailable:
     pass
 else:
     _import_structure["feature_extraction_deformable_detr"] = ["DeformableDetrFeatureExtractor"]
+    _import_structure["image_processing_deformable_detr"] = ["DeformableDetrImageProcessor"]
 try:
     if not is_timm_available():
@@ -57,6 +58,7 @@ if TYPE_CHECKING:
         pass
     else:
         from .feature_extraction_deformable_detr import DeformableDetrFeatureExtractor
+        from .image_processing_deformable_detr import DeformableDetrImageProcessor
     try:
         if not is_timm_available():
......
@@ -30,6 +30,7 @@ except OptionalDependencyNotAvailable:
     pass
 else:
     _import_structure["feature_extraction_detr"] = ["DetrFeatureExtractor"]
+    _import_structure["image_processing_detr"] = ["DetrImageProcessor"]
 try:
     if not is_timm_available():
@@ -56,6 +57,7 @@ if TYPE_CHECKING:
         pass
     else:
         from .feature_extraction_detr import DetrFeatureExtractor
+        from .image_processing_detr import DetrImageProcessor
 try:
     if not is_timm_available():
......
This diff is collapsed.
@@ -1589,7 +1589,7 @@ class DetrForSegmentation(DetrPreTrainedModel):
         >>> import numpy
         >>> from transformers import DetrFeatureExtractor, DetrForSegmentation
-        >>> from transformers.models.detr.feature_extraction_detr import rgb_to_id
+        >>> from transformers.image_transforms import rgb_to_id
         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
@@ -2289,8 +2289,6 @@ def generalized_box_iou(boxes1, boxes2):
 # below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306
 def _max_by_axis(the_list):
     # type: (List[List[int]]) -> List[int]
     maxes = the_list[0]
......
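The docstring now imports `rgb_to_id` from its new public location, `transformers.image_transforms`. It decodes panoptic PNGs in which each pixel's RGB triplet encodes a segment id as `R + 256*G + 256**2*B`; a tiny check of that arithmetic:

```python
import numpy as np

from transformers.image_transforms import rgb_to_id

# A 2x2 "panoptic PNG"-style array; each pixel's RGB triplet encodes one segment id.
rgb = np.array(
    [[[1, 0, 0], [0, 1, 0]],
     [[0, 0, 1], [2, 1, 0]]],
    dtype=np.uint8,
)
ids = rgb_to_id(rgb)
print(ids)  # [[1, 256], [65536, 258]]
```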
@@ -32,6 +32,7 @@ except OptionalDependencyNotAvailable:
     pass
 else:
     _import_structure["feature_extraction_maskformer"] = ["MaskFormerFeatureExtractor"]
+    _import_structure["image_processing_maskformer"] = ["MaskFormerImageProcessor"]
 try:
@@ -63,6 +64,7 @@ if TYPE_CHECKING:
         pass
     else:
         from .feature_extraction_maskformer import MaskFormerFeatureExtractor
+        from .image_processing_maskformer import MaskFormerImageProcessor
     try:
         if not is_torch_available():
             raise OptionalDependencyNotAvailable()
......