Unverified Commit 39084ca6 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Add the ImageClassificationPipeline (#11598)



* Add the ImageClassificationPipeline

* Code review
Co-authored-by: default avatarpatrickvonplaten <patrick.v.platen@gmail.com>

* Have `load_image` at the module level
Co-authored-by: default avatarpatrickvonplaten <patrick.v.platen@gmail.com>
parent e7bff0aa
...@@ -37,6 +37,7 @@ jobs: ...@@ -37,6 +37,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip install --upgrade pip pip install --upgrade pip
sudo apt -y update && sudo apt install -y libsndfile1-dev
pip install .[dev] pip install .[dev]
- name: Create model files - name: Create model files
run: | run: |
......
...@@ -36,6 +36,7 @@ There are two categories of pipeline abstractions to be aware about: ...@@ -36,6 +36,7 @@ There are two categories of pipeline abstractions to be aware about:
- :class:`~transformers.ZeroShotClassificationPipeline` - :class:`~transformers.ZeroShotClassificationPipeline`
- :class:`~transformers.Text2TextGenerationPipeline` - :class:`~transformers.Text2TextGenerationPipeline`
- :class:`~transformers.TableQuestionAnsweringPipeline` - :class:`~transformers.TableQuestionAnsweringPipeline`
- :class:`~transformers.ImageClassificationPipeline`
The pipeline abstraction The pipeline abstraction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -79,6 +80,13 @@ FillMaskPipeline ...@@ -79,6 +80,13 @@ FillMaskPipeline
:special-members: __call__ :special-members: __call__
:members: :members:
ImageClassificationPipeline
=======================================================================================================================
.. autoclass:: transformers.ImageClassificationPipeline
:special-members: __call__
:members:
NerPipeline NerPipeline
======================================================================================================================= =======================================================================================================================
......
...@@ -128,6 +128,13 @@ AutoModelForTableQuestionAnswering ...@@ -128,6 +128,13 @@ AutoModelForTableQuestionAnswering
:members: :members:
AutoModelForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoModelForImageClassification
:members:
TFAutoModel TFAutoModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -244,6 +244,7 @@ _import_structure = { ...@@ -244,6 +244,7 @@ _import_structure = {
"CsvPipelineDataFormat", "CsvPipelineDataFormat",
"FeatureExtractionPipeline", "FeatureExtractionPipeline",
"FillMaskPipeline", "FillMaskPipeline",
"ImageClassificationPipeline",
"JsonPipelineDataFormat", "JsonPipelineDataFormat",
"NerPipeline", "NerPipeline",
"PipedPipelineDataFormat", "PipedPipelineDataFormat",
...@@ -483,6 +484,7 @@ if is_torch_available(): ...@@ -483,6 +484,7 @@ if is_torch_available():
"MODEL_WITH_LM_HEAD_MAPPING", "MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel", "AutoModel",
"AutoModelForCausalLM", "AutoModelForCausalLM",
"AutoModelForImageClassification",
"AutoModelForMaskedLM", "AutoModelForMaskedLM",
"AutoModelForMultipleChoice", "AutoModelForMultipleChoice",
"AutoModelForNextSentencePrediction", "AutoModelForNextSentencePrediction",
...@@ -1640,6 +1642,7 @@ if TYPE_CHECKING: ...@@ -1640,6 +1642,7 @@ if TYPE_CHECKING:
CsvPipelineDataFormat, CsvPipelineDataFormat,
FeatureExtractionPipeline, FeatureExtractionPipeline,
FillMaskPipeline, FillMaskPipeline,
ImageClassificationPipeline,
JsonPipelineDataFormat, JsonPipelineDataFormat,
NerPipeline, NerPipeline,
PipedPipelineDataFormat, PipedPipelineDataFormat,
...@@ -1845,6 +1848,7 @@ if TYPE_CHECKING: ...@@ -1845,6 +1848,7 @@ if TYPE_CHECKING:
MODEL_WITH_LM_HEAD_MAPPING, MODEL_WITH_LM_HEAD_MAPPING,
AutoModel, AutoModel,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedLM, AutoModelForMaskedLM,
AutoModelForMultipleChoice, AutoModelForMultipleChoice,
AutoModelForNextSentencePrediction, AutoModelForNextSentencePrediction,
......
...@@ -226,7 +226,7 @@ class FeatureExtractionMixin: ...@@ -226,7 +226,7 @@ class FeatureExtractionMixin:
:func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g., :func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g.,
``./my_model_directory/``. ``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g., - a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``. ``./my_model_directory/preprocessor_config.json``.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model feature extractor should be cached if the Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
standard cache should not be used. standard cache should not be used.
......
...@@ -14,34 +14,26 @@ ...@@ -14,34 +14,26 @@
# limitations under the License. # limitations under the License.
""" AutoFeatureExtractor class. """ """ AutoFeatureExtractor class. """
import os
from collections import OrderedDict from collections import OrderedDict
from ...feature_extraction_utils import FeatureExtractionMixin from transformers import DeiTFeatureExtractor, Speech2TextFeatureExtractor, ViTFeatureExtractor
from ...file_utils import is_speech_available, is_vision_available
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .configuration_auto import replace_list_option_in_docstrings
if is_speech_available(): from ... import DeiTConfig, PretrainedConfig, Speech2TextConfig, ViTConfig, Wav2Vec2Config
from ..speech_to_text.feature_extraction_speech_to_text import Speech2TextFeatureExtractor from ...feature_extraction_utils import FeatureExtractionMixin
else:
Speech2TextFeatureExtractor = None
if is_vision_available(): # Build the list of all feature extractors
from ..deit.feature_extraction_deit import DeiTFeatureExtractor from ...file_utils import FEATURE_EXTRACTOR_NAME
from ..vit.feature_extraction_vit import ViTFeatureExtractor from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
else: from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
DeiTFeatureExtractor = None
ViTFeatureExtractor = None
# Build the list of all feature extractors
FEATURE_EXTRACTOR_MAPPING = OrderedDict( FEATURE_EXTRACTOR_MAPPING = OrderedDict(
[ [
("deit", DeiTFeatureExtractor), (DeiTConfig, DeiTFeatureExtractor),
("s2t", Speech2TextFeatureExtractor), (Speech2TextConfig, Speech2TextFeatureExtractor),
("vit", ViTFeatureExtractor), (ViTConfig, ViTFeatureExtractor),
("wav2vec2", Wav2Vec2FeatureExtractor), (Wav2Vec2Config, Wav2Vec2FeatureExtractor),
] ]
) )
...@@ -89,7 +81,7 @@ class AutoFeatureExtractor: ...@@ -89,7 +81,7 @@ class AutoFeatureExtractor:
:func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g., :func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g.,
``./my_model_directory/``. ``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g., - a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``. ``./my_model_directory/preprocessor_config.json``.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model feature extractor should be cached if the Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
standard cache should not be used. standard cache should not be used.
...@@ -134,20 +126,29 @@ class AutoFeatureExtractor: ...@@ -134,20 +126,29 @@ class AutoFeatureExtractor:
>>> feature_extractor = AutoFeatureExtractor.from_pretrained('./test/saved_model/') >>> feature_extractor = AutoFeatureExtractor.from_pretrained('./test/saved_model/')
""" """
config = kwargs.pop("config", None)
kwargs["_from_auto"] = True
is_feature_extraction_file = os.path.isfile(pretrained_model_name_or_path)
is_directory = os.path.isdir(pretrained_model_name_or_path) and os.path.exists(
os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
)
if not is_feature_extraction_file and not is_directory:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
if "feature_extractor_type" in config_dict: if type(config) in FEATURE_EXTRACTOR_MAPPING.keys():
return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
elif "feature_extractor_type" in config_dict:
feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"]) feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
return feature_extractor_class.from_dict(config_dict, **kwargs) return feature_extractor_class.from_dict(config_dict, **kwargs)
else:
# Fallback: use pattern matching on the string.
for pattern, feature_extractor_class in FEATURE_EXTRACTOR_MAPPING.items():
if pattern in str(pretrained_model_name_or_path):
return feature_extractor_class.from_dict(config_dict, **kwargs)
raise ValueError( raise ValueError(
f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in " f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
"its feature_extraction_config.json, or contain one of the following strings " f"its {FEATURE_EXTRACTOR_NAME}, or contain one of the following strings "
f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}" f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}"
) )
...@@ -97,7 +97,7 @@ class Speech2TextProcessor: ...@@ -97,7 +97,7 @@ class Speech2TextProcessor:
:meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g., :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
``./my_model_directory/``. ``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g., - a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``. ``./my_model_directory/preprocessor_config.json``.
**kwargs **kwargs
Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and
:class:`~transformers.PreTrainedTokenizer` :class:`~transformers.PreTrainedTokenizer`
......
...@@ -96,7 +96,7 @@ class Wav2Vec2Processor: ...@@ -96,7 +96,7 @@ class Wav2Vec2Processor:
:meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g., :meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g.,
``./my_model_directory/``. ``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g., - a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``. ``./my_model_directory/preprocessor_config.json``.
**kwargs **kwargs
Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and
:class:`~transformers.PreTrainedTokenizer` :class:`~transformers.PreTrainedTokenizer`
......
...@@ -20,9 +20,12 @@ import warnings ...@@ -20,9 +20,12 @@ import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from ..configuration_utils import PretrainedConfig from ..configuration_utils import PretrainedConfig
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import is_tf_available, is_torch_available from ..file_utils import is_tf_available, is_torch_available
from ..modelcard import ModelCard from ..modelcard import ModelCard
from ..models.auto.tokenization_auto import AutoTokenizer from ..models.auto.configuration_auto import AutoConfig
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..tokenization_utils import PreTrainedTokenizer from ..tokenization_utils import PreTrainedTokenizer
from ..utils import logging from ..utils import logging
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
...@@ -40,6 +43,7 @@ from .base import ( ...@@ -40,6 +43,7 @@ from .base import (
from .conversational import Conversation, ConversationalPipeline from .conversational import Conversation, ConversationalPipeline
from .feature_extraction import FeatureExtractionPipeline from .feature_extraction import FeatureExtractionPipeline
from .fill_mask import FillMaskPipeline from .fill_mask import FillMaskPipeline
from .image_classification import ImageClassificationPipeline
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
...@@ -79,6 +83,7 @@ if is_torch_available(): ...@@ -79,6 +83,7 @@ if is_torch_available():
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
AutoModel, AutoModel,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedLM, AutoModelForMaskedLM,
AutoModelForQuestionAnswering, AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM,
...@@ -198,6 +203,12 @@ SUPPORTED_TASKS = { ...@@ -198,6 +203,12 @@ SUPPORTED_TASKS = {
"pt": AutoModelForCausalLM if is_torch_available() else None, "pt": AutoModelForCausalLM if is_torch_available() else None,
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}}, "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
}, },
"image-classification": {
"impl": ImageClassificationPipeline,
"tf": None,
"pt": AutoModelForImageClassification if is_torch_available() else None,
"default": {"model": {"pt": "google/vit-base-patch16-224"}},
},
} }
...@@ -252,6 +263,7 @@ def pipeline( ...@@ -252,6 +263,7 @@ def pipeline(
model: Optional = None, model: Optional = None,
config: Optional[Union[str, PretrainedConfig]] = None, config: Optional[Union[str, PretrainedConfig]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
framework: Optional[str] = None, framework: Optional[str] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
use_fast: bool = True, use_fast: bool = True,
...@@ -309,6 +321,18 @@ def pipeline( ...@@ -309,6 +321,18 @@ def pipeline(
:obj:`model` is not specified or not a string, then the default tokenizer for :obj:`config` is loaded (if :obj:`model` is not specified or not a string, then the default tokenizer for :obj:`config` is loaded (if
it is a string). However, if :obj:`config` is also not given or not a string, then the default tokenizer it is a string). However, if :obj:`config` is also not given or not a string, then the default tokenizer
for the given :obj:`task` will be loaded. for the given :obj:`task` will be loaded.
feature_extractor (:obj:`str` or :obj:`~transformers.PreTrainedFeatureExtractor`, `optional`):
The feature extractor that will be used by the pipeline to encode data for the model. This can be a model
identifier or an actual pretrained feature extractor inheriting from
:class:`~transformers.PreTrainedFeatureExtractor`.
Feature extractors are used for non-NLP models, such as Speech or Vision models as well as multi-modal
models. Multi-modal models will also require a tokenizer to be passed.
If not provided, the default feature extractor for the given :obj:`model` will be loaded (if it is a
string). If :obj:`model` is not specified or not a string, then the default feature extractor for
:obj:`config` is loaded (if it is a string). However, if :obj:`config` is also not given or not a string,
then the default feature extractor for the given :obj:`task` will be loaded.
framework (:obj:`str`, `optional`): framework (:obj:`str`, `optional`):
The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
must be installed. must be installed.
...@@ -359,19 +383,7 @@ def pipeline( ...@@ -359,19 +383,7 @@ def pipeline(
# At that point framework might still be undetermined # At that point framework might still be undetermined
model = get_default_model(targeted_task, framework, task_options) model = get_default_model(targeted_task, framework, task_options)
# Try to infer tokenizer from model or config name (if provided as str) model_name = model if isinstance(model, str) else None
if tokenizer is None:
if isinstance(model, str):
tokenizer = model
elif isinstance(config, str):
tokenizer = config
else:
# Impossible to guest what is the right tokenizer here
raise Exception(
"Impossible to guess which tokenizer to use. "
"Please provided a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
)
modelcard = None modelcard = None
# Try to infer modelcard from model or config name (if provided as str) # Try to infer modelcard from model or config name (if provided as str)
if isinstance(model, str): if isinstance(model, str):
...@@ -388,19 +400,6 @@ def pipeline( ...@@ -388,19 +400,6 @@ def pipeline(
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token) model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
# Instantiate tokenizer if needed
if isinstance(tokenizer, (str, tuple)):
if isinstance(tokenizer, tuple):
# For tuple we have (tokenizer name, {kwargs})
use_fast = tokenizer[1].pop("use_fast", use_fast)
tokenizer = AutoTokenizer.from_pretrained(
tokenizer[0], use_fast=use_fast, revision=revision, _from_pipeline=task, **tokenizer[1]
)
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer, revision=revision, use_fast=use_fast, _from_pipeline=task, **model_kwargs
)
# Instantiate config if needed # Instantiate config if needed
if isinstance(config, str): if isinstance(config, str):
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs) config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
...@@ -434,6 +433,61 @@ def pipeline( ...@@ -434,6 +433,61 @@ def pipeline(
model, config=config, revision=revision, _from_pipeline=task, **model_kwargs model, config=config, revision=revision, _from_pipeline=task, **model_kwargs
) )
model_config = model.config
load_tokenizer = type(model_config) in TOKENIZER_MAPPING
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING
if load_tokenizer:
# Try to infer tokenizer from model or config name (if provided as str)
if tokenizer is None:
if isinstance(model_name, str):
tokenizer = model_name
elif isinstance(config, str):
tokenizer = config
else:
# Impossible to guess what is the right tokenizer here
raise Exception(
"Impossible to guess which tokenizer to use. "
"Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
)
# Instantiate tokenizer if needed
if isinstance(tokenizer, (str, tuple)):
if isinstance(tokenizer, tuple):
# For tuple we have (tokenizer name, {kwargs})
use_fast = tokenizer[1].pop("use_fast", use_fast)
tokenizer_identifier = tokenizer[0]
tokenizer_kwargs = tokenizer[1]
else:
tokenizer_identifier = tokenizer
tokenizer_kwargs = model_kwargs
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_identifier, revision=revision, use_fast=use_fast, _from_pipeline=task, **tokenizer_kwargs
)
if load_feature_extractor:
# Try to infer feature extractor from model or config name (if provided as str)
if feature_extractor is None:
if isinstance(model_name, str):
feature_extractor = model_name
elif isinstance(config, str):
feature_extractor = config
else:
# Impossible to guess what is the right feature_extractor here
raise Exception(
"Impossible to guess which feature extractor to use. "
"Please provide a PreTrainedFeatureExtractor class or a path/identifier "
"to a pretrained feature extractor."
)
# Instantiate feature_extractor if needed
if isinstance(feature_extractor, (str, tuple)):
feature_extractor = AutoFeatureExtractor.from_pretrained(
feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs
)
if task == "translation" and model.config.task_specific_params: if task == "translation" and model.config.task_specific_params:
for key in model.config.task_specific_params: for key in model.config.task_specific_params:
if key.startswith("translation"): if key.startswith("translation"):
...@@ -444,4 +498,16 @@ def pipeline( ...@@ -444,4 +498,16 @@ def pipeline(
) )
break break
return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs) if tokenizer is not None:
kwargs["tokenizer"] = tokenizer
if feature_extractor is not None:
kwargs["feature_extractor"] = feature_extractor
return task_class(
model=model,
modelcard=modelcard,
framework=framework,
task=task,
**kwargs,
)
...@@ -23,6 +23,7 @@ from contextlib import contextmanager ...@@ -23,6 +23,7 @@ from contextlib import contextmanager
from os.path import abspath, exists from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
...@@ -522,7 +523,8 @@ class Pipeline(_ScikitCompat): ...@@ -522,7 +523,8 @@ class Pipeline(_ScikitCompat):
def __init__( def __init__(
self, self,
model: Union["PreTrainedModel", "TFPreTrainedModel"], model: Union["PreTrainedModel", "TFPreTrainedModel"],
tokenizer: PreTrainedTokenizer, tokenizer: Optional[PreTrainedTokenizer] = None,
feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
modelcard: Optional[ModelCard] = None, modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None, framework: Optional[str] = None,
task: str = "", task: str = "",
...@@ -537,6 +539,7 @@ class Pipeline(_ScikitCompat): ...@@ -537,6 +539,7 @@ class Pipeline(_ScikitCompat):
self.task = task self.task = task
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.feature_extractor = feature_extractor
self.modelcard = modelcard self.modelcard = modelcard
self.framework = framework self.framework = framework
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}") self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
...@@ -565,7 +568,13 @@ class Pipeline(_ScikitCompat): ...@@ -565,7 +568,13 @@ class Pipeline(_ScikitCompat):
os.makedirs(save_directory, exist_ok=True) os.makedirs(save_directory, exist_ok=True)
self.model.save_pretrained(save_directory) self.model.save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(save_directory)
if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(save_directory)
if self.modelcard is not None: if self.modelcard is not None:
self.modelcard.save_pretrained(save_directory) self.modelcard.save_pretrained(save_directory)
...@@ -630,7 +639,14 @@ class Pipeline(_ScikitCompat): ...@@ -630,7 +639,14 @@ class Pipeline(_ScikitCompat):
The list of models supported by the pipeline, or a dictionary with model class values. The list of models supported by the pipeline, or a dictionary with model class values.
""" """
if not isinstance(supported_models, list): # Create from a model mapping if not isinstance(supported_models, list): # Create from a model mapping
supported_models = [item[1].__name__ for item in supported_models.items()] supported_models_names = []
for config, model in supported_models.items():
# Mapping can now contain tuples of models for the same configuration.
if isinstance(model, tuple):
supported_models_names.extend([_model.__name__ for _model in model])
else:
supported_models_names.append(model.__name__)
supported_models = supported_models_names
if self.model.__class__.__name__ not in supported_models: if self.model.__class__.__name__ not in supported_models:
raise PipelineException( raise PipelineException(
self.task, self.task,
......
import os
from typing import TYPE_CHECKING, List, Optional, Union
import requests
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
if is_vision_available():
from PIL import Image
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ImageClassificationPipeline(Pipeline):
"""
Image classification pipeline using any :obj:`AutoModelForImageClassification`. This pipeline predicts the class of
an image.
This image classification pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
task identifier: :obj:`"image-classification"`.
See the list of available models on `huggingface.co/models
<https://huggingface.co/models?filter=image-classification>`__.
"""
def __init__(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel"],
feature_extractor: PreTrainedFeatureExtractor,
framework: Optional[str] = None,
**kwargs
):
super().__init__(model, feature_extractor=feature_extractor, framework=framework, **kwargs)
if self.framework == "tf":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
self.feature_extractor = feature_extractor
@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
return Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
return Image.open(image)
elif isinstance(image, Image.Image):
return image
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], top_k=5):
"""
Assign labels to the image(s) passed as inputs.
Args:
images (:obj:`str`, :obj:`List[str]`, :obj:`PIL.Image` or :obj:`List[PIL.Image]`):
The pipeline handles three types of images:
- A string containing a http link pointing to an image
- A string containing a local path to an image
- An image loaded in PIL directly
The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
images.
top_k (:obj:`int`, `optional`, defaults to 5):
The number of top labels that will be returned by the pipeline.
Return:
A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
dictionary, if the input is a list of several images, will return a list of dictionaries corresponding to
the images.
The dictionaries contain the following keys:
- **label** (:obj:`str`) -- The label identified by the model.
- **score** (:obj:`int`) -- The score attributed by the model for that label.
"""
is_batched = isinstance(images, list)
if not is_batched:
images = [images]
images = [self.load_image(image) for image in images]
with torch.no_grad():
inputs = self.feature_extractor(images=images, return_tensors="pt")
outputs = self.model(**inputs)
probs = outputs.logits.softmax(-1)
scores, ids = probs.topk(top_k)
scores = scores.tolist()
ids = ids.tolist()
if not is_batched:
scores, ids = scores[0], ids[0]
labels = [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
else:
labels = []
for scores, ids in zip(scores, ids):
labels.append(
[{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
)
return labels
...@@ -376,6 +376,15 @@ class AutoModelForCausalLM: ...@@ -376,6 +376,15 @@ class AutoModelForCausalLM:
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class AutoModelForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoModelForMaskedLM: class AutoModelForMaskedLM:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
......
{
"feature_extractor_type": "Wav2Vec2FeatureExtractor"
}
\ No newline at end of file
...@@ -16,9 +16,10 @@ ...@@ -16,9 +16,10 @@
import os import os
import unittest import unittest
from transformers import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor, Wav2Vec2FeatureExtractor from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join( SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json" os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
) )
...@@ -29,16 +30,10 @@ class AutoFeatureExtractorTest(unittest.TestCase): ...@@ -29,16 +30,10 @@ class AutoFeatureExtractorTest(unittest.TestCase):
config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(config, Wav2Vec2FeatureExtractor) self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_directory(self):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_file(self): def test_feature_extractor_from_local_file(self):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG) config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor) self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_pattern_matching_fallback(self):
"""
In cases where config.json doesn't include a model_type,
perform a few safety checks on the config mapping's order.
"""
# no key string should be included in a later key string (typical failure case)
keys = list(FEATURE_EXTRACTOR_MAPPING.keys())
for i, key in enumerate(keys):
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import (
AutoFeatureExtractor,
AutoModelForImageClassification,
PreTrainedTokenizer,
is_vision_available,
)
from transformers.pipelines import ImageClassificationPipeline, pipeline
from transformers.testing_utils import require_torch, require_vision
if is_vision_available():
from PIL import Image
else:
class Image:
@staticmethod
def open(*args, **kwargs):
pass
@require_vision
@require_torch
class ImageClassificationPipelineTests(unittest.TestCase):
pipeline_task = "image-classification"
small_models = ["lysandre/tiny-vit-random"] # Models tested without the @slow decorator
valid_inputs = [
{"images": "http://images.cocodataset.org/val2017/000000039769.jpg"},
{
"images": [
"http://images.cocodataset.org/val2017/000000039769.jpg",
"http://images.cocodataset.org/val2017/000000039769.jpg",
]
},
{"images": "tests/fixtures/coco.jpg"},
{"images": ["tests/fixtures/coco.jpg", "tests/fixtures/coco.jpg"]},
{"images": Image.open("tests/fixtures/coco.jpg")},
{"images": [Image.open("tests/fixtures/coco.jpg"), Image.open("tests/fixtures/coco.jpg")]},
{"images": [Image.open("tests/fixtures/coco.jpg"), "tests/fixtures/coco.jpg"]},
]
def test_small_model_from_factory(self):
for small_model in self.small_models:
image_classifier = pipeline("image-classification", model=small_model)
for valid_input in self.valid_inputs:
output = image_classifier(**valid_input)
top_k = valid_input.get("top_k", 5)
def assert_valid_pipeline_output(pipeline_output):
self.assertTrue(isinstance(pipeline_output, list))
self.assertEqual(len(pipeline_output), top_k)
for label_result in pipeline_output:
self.assertTrue(isinstance(label_result, dict))
self.assertIn("label", label_result)
self.assertIn("score", label_result)
if isinstance(valid_input["images"], list):
self.assertEqual(len(valid_input["images"]), len(output))
for individual_output in output:
assert_valid_pipeline_output(individual_output)
else:
assert_valid_pipeline_output(output)
def test_small_model_from_pipeline(self):
for small_model in self.small_models:
model = AutoModelForImageClassification.from_pretrained(small_model)
feature_extractor = AutoFeatureExtractor.from_pretrained(small_model)
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
for valid_input in self.valid_inputs:
output = image_classifier(**valid_input)
top_k = valid_input.get("top_k", 5)
def assert_valid_pipeline_output(pipeline_output):
self.assertTrue(isinstance(pipeline_output, list))
self.assertEqual(len(pipeline_output), top_k)
for label_result in pipeline_output:
self.assertTrue(isinstance(label_result, dict))
self.assertIn("label", label_result)
self.assertIn("score", label_result)
if isinstance(valid_input["images"], list):
# When images are batched, pipeline output is a list of lists of dictionaries
self.assertEqual(len(valid_input["images"]), len(output))
for individual_output in output:
assert_valid_pipeline_output(individual_output)
else:
# When images are batched, pipeline output is a list of dictionaries
assert_valid_pipeline_output(output)
def test_custom_tokenizer(self):
tokenizer = PreTrainedTokenizer()
# Assert that the pipeline can be initialized with a feature extractor that is not in any mapping
image_classifier = pipeline("image-classification", model=self.small_models[0], tokenizer=tokenizer)
self.assertIs(image_classifier.tokenizer, tokenizer)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment