Unverified Commit b2c477fc authored by Minghao Li's avatar Minghao Li Committed by GitHub
Browse files

support the trocr small models (#14893)



* support the trocr small models

* resolve conflict

* Update docs/source/model_doc/trocr.mdx
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/model_doc/trocr.mdx
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/model_doc/trocr.mdx
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/trocr/processing_trocr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/trocr/processing_trocr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/trocr/processing_trocr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/trocr/processing_trocr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* fix unexpected indent in processing_trocr.py

* Update src/transformers/models/trocr/processing_trocr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* update the docstring of processing_trocr

* remove extra space
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>
parent 42d57549
...@@ -55,9 +55,9 @@ Tips: ...@@ -55,9 +55,9 @@ Tips:
TrOCR's [`VisionEncoderDecoder`] model accepts images as input and makes use of TrOCR's [`VisionEncoderDecoder`] model accepts images as input and makes use of
[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image. [`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
The [`ViTFeatureExtractor`] class is responsible for preprocessing the input image and The [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] class is responsible for preprocessing the input image and
[`RobertaTokenizer`] decodes the generated target tokens to the target string. The [`RobertaTokenizer`/`XLMRobertaTokenizer`] decodes the generated target tokens to the target string. The
[`TrOCRProcessor`] wraps [`ViTFeatureExtractor`] and [`RobertaTokenizer`] [`TrOCRProcessor`] wraps [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] and [`RobertaTokenizer`/`XLMRobertaTokenizer`]
into a single instance to both extract the input features and decode the predicted token ids. into a single instance to both extract the input features and decode the predicted token ids.
- Step-by-step Optical Character Recognition (OCR) - Step-by-step Optical Character Recognition (OCR)
......
...@@ -20,22 +20,24 @@ from contextlib import contextmanager ...@@ -20,22 +20,24 @@ from contextlib import contextmanager
from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer from transformers.models.roberta.tokenization_roberta import RobertaTokenizer
from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast
from transformers.models.xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
from transformers.models.xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
from ..auto.feature_extraction_auto import AutoFeatureExtractor from transformers import AutoTokenizer, AutoFeatureExtractor
class TrOCRProcessor: class TrOCRProcessor:
r""" r"""
Constructs a TrOCR processor which wraps a vision feature extractor and a TrOCR tokenizer into a single processor. Constructs a TrOCR processor which wraps a vision feature extractor and a TrOCR tokenizer into a single processor.
[`TrOCRProcessor`] offers all the functionalities of [`AutoFeatureExtractor`] and [`RobertaTokenizer`]. See the [`TrOCRProcessor`] offers all the functionalities of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] and [`RobertaTokenizer`/`XLMRobertaTokenizer`]. See the
[`~TrOCRProcessor.__call__`] and [`~TrOCRProcessor.decode`] for more information. [`~TrOCRProcessor.__call__`] and [`~TrOCRProcessor.decode`] for more information.
Args: Args:
feature_extractor ([`AutoFeatureExtractor`]): feature_extractor ([`ViTFeatureExtractor`/`DeiTFeatureExtractor`]):
An instance of [`AutoFeatureExtractor`]. The feature extractor is a required input. An instance of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`]. The feature extractor is a required input.
tokenizer ([`RobertaTokenizer`]): tokenizer ([`RobertaTokenizer`/`XLMRobertaTokenizer`]):
An instance of [`RobertaTokenizer`]. The tokenizer is a required input. An instance of [`RobertaTokenizer`/`XLMRobertaTokenizer`]. The tokenizer is a required input.
""" """
def __init__(self, feature_extractor, tokenizer): def __init__(self, feature_extractor, tokenizer):
...@@ -43,9 +45,9 @@ class TrOCRProcessor: ...@@ -43,9 +45,9 @@ class TrOCRProcessor:
raise ValueError( raise ValueError(
f"`feature_extractor` has to be of type {FeatureExtractionMixin.__class__}, but is {type(feature_extractor)}" f"`feature_extractor` has to be of type {FeatureExtractionMixin.__class__}, but is {type(feature_extractor)}"
) )
if not (isinstance(tokenizer, RobertaTokenizer) or (isinstance(tokenizer, RobertaTokenizerFast))): if not isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, XLMRobertaTokenizer, XLMRobertaTokenizerFast)):
raise ValueError( raise ValueError(
f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__}, but is {type(tokenizer)}" f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__} or {XLMRobertaTokenizer.__class__} or {XLMRobertaTokenizerFast.__class__}, but is {type(tokenizer)}"
) )
self.feature_extractor = feature_extractor self.feature_extractor = feature_extractor
...@@ -103,7 +105,7 @@ class TrOCRProcessor: ...@@ -103,7 +105,7 @@ class TrOCRProcessor:
[`PreTrainedTokenizer`] [`PreTrainedTokenizer`]
""" """
feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs) feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer) return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment