"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "eb1a77bbb0a62d721e9a02e67b7e4f9e5afca08b"
Unverified Commit b2c477fc authored by Minghao Li's avatar Minghao Li Committed by GitHub
Browse files

support the trocr small models (#14893)



* support the trocr small models

* resolve conflict

* Update docs/source/model_doc/trocr.mdx
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/model_doc/trocr.mdx
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/model_doc/trocr.mdx
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/trocr/processing_trocr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/trocr/processing_trocr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/trocr/processing_trocr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/trocr/processing_trocr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* fix unexpected indent in processing_trocr.py

* Update src/transformers/models/trocr/processing_trocr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* update the docstring of processing_trocr

* remove extra space
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>
parent 42d57549
......@@ -55,9 +55,9 @@ Tips:
TrOCR's [`VisionEncoderDecoder`] model accepts images as input and makes use of
[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
The [`ViTFeatureExtractor`] class is responsible for preprocessing the input image and
[`RobertaTokenizer`] decodes the generated target tokens to the target string. The
[`TrOCRProcessor`] wraps [`ViTFeatureExtractor`] and [`RobertaTokenizer`]
The [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] class is responsible for preprocessing the input image and
[`RobertaTokenizer`/`XLMRobertaTokenizer`] decodes the generated target tokens to the target string. The
[`TrOCRProcessor`] wraps [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] and [`RobertaTokenizer`/`XLMRobertaTokenizer`]
into a single instance to both extract the input features and decode the predicted token ids.
- Step-by-step Optical Character Recognition (OCR)
......
......@@ -20,22 +20,24 @@ from contextlib import contextmanager
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer
from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast
from transformers.models.xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
from transformers.models.xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
from ..auto.feature_extraction_auto import AutoFeatureExtractor
from transformers import AutoTokenizer, AutoFeatureExtractor
class TrOCRProcessor:
r"""
Constructs a TrOCR processor which wraps a vision feature extractor and a TrOCR tokenizer into a single processor.
[`TrOCRProcessor`] offers all the functionalities of [`AutoFeatureExtractor`] and [`RobertaTokenizer`]. See the
[`TrOCRProcessor`] offers all the functionalities of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] and [`RobertaTokenizer`/`XLMRobertaTokenizer`]. See the
[`~TrOCRProcessor.__call__`] and [`~TrOCRProcessor.decode`] for more information.
Args:
feature_extractor ([`AutoFeatureExtractor`]):
An instance of [`AutoFeatureExtractor`]. The feature extractor is a required input.
tokenizer ([`RobertaTokenizer`]):
An instance of [`RobertaTokenizer`]. The tokenizer is a required input.
feature_extractor ([`ViTFeatureExtractor`/`DeiTFeatureExtractor`]):
An instance of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`]. The feature extractor is a required input.
tokenizer ([`RobertaTokenizer`/`XLMRobertaTokenizer`]):
An instance of [`RobertaTokenizer`/`XLMRobertaTokenizer`]. The tokenizer is a required input.
"""
def __init__(self, feature_extractor, tokenizer):
......@@ -43,9 +45,9 @@ class TrOCRProcessor:
raise ValueError(
f"`feature_extractor` has to be of type {FeatureExtractionMixin.__class__}, but is {type(feature_extractor)}"
)
if not (isinstance(tokenizer, RobertaTokenizer) or (isinstance(tokenizer, RobertaTokenizerFast))):
if not isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, XLMRobertaTokenizer, XLMRobertaTokenizerFast)):
raise ValueError(
f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__}, but is {type(tokenizer)}"
f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__} or {XLMRobertaTokenizer.__class__} or {XLMRobertaTokenizerFast.__class__}, but is {type(tokenizer)}"
)
self.feature_extractor = feature_extractor
......@@ -103,7 +105,7 @@ class TrOCRProcessor:
[`PreTrainedTokenizer`]
"""
feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment