"tests/vscode:/vscode.git/clone" did not exist on "8801861d2de1568e8ca8f81d96a7ddf3964f6373"
Unverified commit 986526a0, authored by Sylvain Gugger and committed by GitHub

Replace `as_target` context managers by direct calls (#18325)



* Preliminary work on tokenizers

* Quality + fix tests

* Treat processors

* Fix pad

* Remove all uses of  in tests, docs and examples

* Replace all as_target_tokenizer

* Fix tests

* Fix quality

* Update examples/flax/image-captioning/run_image_captioning_flax.py
Co-authored-by: amyeroberts <amy@huggingface.co>

* Style
Co-authored-by: amyeroberts <amy@huggingface.co>
parent a64bcb56
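In short: target texts are now passed directly to the tokenizer via `text_target` (and to speech processors via `text`) instead of being tokenized inside an `as_target_tokenizer()` / `as_target_processor()` block. A minimal before/after sketch, assuming the `google/mt5-small` checkpoint used in the docstrings below can be downloaded:

```python
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
summary = "Weiter Verhandlung in Syrien."

# Old pattern (now deprecated, still functional):
# with tokenizer.as_target_tokenizer():
#     labels = tokenizer(summary, return_tensors="pt")

# New pattern: one call encodes source and target; the target ids come back under "labels".
inputs = tokenizer(article, text_target=summary, return_tensors="pt")
print(inputs.keys())  # expect input_ids, attention_mask and labels
```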
@@ -14,7 +14,6 @@
 # limitations under the License.
 import os
-from contextlib import contextmanager
 from shutil import copyfile
 from typing import List, Optional, Tuple
@@ -98,10 +97,8 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
     >>> tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
     >>> src_text = " UN Chief Says There Is No Military Solution in Syria"
     >>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
-    >>> model_inputs = tokenizer(src_text, return_tensors="pt")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(tgt_text, return_tensors="pt").input_ids
-    >>> # model(**model_inputs, labels=labels) should work
+    >>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
+    >>> # model(**model_inputs) should work
     ```"""
     vocab_files_names = VOCAB_FILES_NAMES
@@ -211,15 +208,11 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
         self.tgt_lang = tgt_lang
         return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
-    @contextmanager
-    def as_target_tokenizer(self):
-        """
-        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
-        sequence-to-sequence models that need a slightly different processing for the labels.
-        """
-        self.set_tgt_lang_special_tokens(self.tgt_lang)
-        yield
-        self.set_src_lang_special_tokens(self.src_lang)
+    def _switch_to_input_mode(self):
+        return self.set_src_lang_special_tokens(self.src_lang)
+
+    def _switch_to_target_mode(self):
+        return self.set_tgt_lang_special_tokens(self.tgt_lang)
     def set_src_lang_special_tokens(self, src_lang: str) -> None:
         """Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
......
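For language-code tokenizers such as MBart-50, the new private `_switch_to_input_mode` / `_switch_to_target_mode` hooks are invoked by the base `__call__` around `text_target`, so labels still get the target-language special tokens without the context manager. A sketch mirroring the updated docstring above (checkpoint download assumed):

```python
from transformers import MBart50TokenizerFast

tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO"
)

src_text = " UN Chief Says There Is No Military Solution in Syria"
tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"

# One call; the tokenizer switches to tgt_lang special tokens only while encoding `text_target`.
model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")

# The inputs should start with the en_XX code, the labels with the ro_RO code.
print(tokenizer.convert_ids_to_tokens(model_inputs["input_ids"][0].tolist())[:2])
print(tokenizer.convert_ids_to_tokens(model_inputs["labels"][0].tolist())[:2])
```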
@@ -15,6 +15,7 @@
 """
 Speech processor class for M-CTC-T
 """
+import warnings
 from contextlib import contextmanager
 from ...processing_utils import ProcessorMixin
@@ -39,6 +40,7 @@ class MCTCTProcessor(ProcessorMixin):
     def __init__(self, feature_extractor, tokenizer):
         super().__init__(feature_extractor, tokenizer)
         self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
     def __call__(self, *args, **kwargs):
         """
@@ -47,8 +49,36 @@ class MCTCTProcessor(ProcessorMixin):
         [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to AutoTokenizer's
         [`~AutoTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more information.
         """
-        return self.current_processor(*args, **kwargs)
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor(*args, **kwargs)
+
+        if "raw_speech" in kwargs:
+            warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+            audio = kwargs.pop("raw_speech")
+        else:
+            audio = kwargs.pop("audio", None)
+        text = kwargs.pop("text", None)
+        if len(args) > 0:
+            audio = args[0]
+            args = args[1:]
+
+        if audio is None and text is None:
+            raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+        if audio is not None:
+            inputs = self.feature_extractor(audio, *args, **kwargs)
+        if text is not None:
+            encodings = self.tokenizer(text, **kwargs)
+
+        if text is None:
+            return inputs
+        elif audio is None:
+            return encodings
+        else:
+            inputs["labels"] = encodings["input_ids"]
+            return inputs
     def batch_decode(self, *args, **kwargs):
         """
         This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer
@@ -63,8 +93,29 @@ class MCTCTProcessor(ProcessorMixin):
         [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
         [`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.
         """
-        return self.current_processor.pad(*args, **kwargs)
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor.pad(*args, **kwargs)
+
+        input_features = kwargs.pop("input_features", None)
+        labels = kwargs.pop("labels", None)
+        if len(args) > 0:
+            input_features = args[0]
+            args = args[1:]
+
+        if input_features is not None:
+            input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
+        if labels is not None:
+            labels = self.tokenizer.pad(labels, **kwargs)
+
+        if labels is None:
+            return input_features
+        elif input_features is None:
+            return labels
+        else:
+            input_features["labels"] = labels["input_ids"]
+            return input_features
     def decode(self, *args, **kwargs):
         """
         This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
@@ -77,6 +128,13 @@ class MCTCTProcessor(ProcessorMixin):
         """
         Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning MCTCT.
         """
+        warnings.warn(
+            "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+            "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+            "your audio inputs, or in a separate call."
+        )
+        self._in_target_context_manager = True
         self.current_processor = self.tokenizer
         yield
         self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
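With this dispatch, a speech processor can now encode audio inputs and text labels in a single call. A hedged sketch of the new call signature; `processor`, `speech_array` and `sampling_rate` are assumed to come from your own setup (e.g. an instantiated `MCTCTProcessor` and a dataset sample):

```python
# `processor`, `speech_array` and `sampling_rate` are assumed to exist already;
# this only illustrates the call signature introduced above.

# Audio only -> forwarded to the feature extractor.
inputs = processor(audio=speech_array, sampling_rate=sampling_rate, return_tensors="pt")

# Text only -> forwarded to the tokenizer (what `as_target_processor` used to be for).
labels = processor(text="hello world", return_tensors="pt")

# Both at once -> feature-extractor output with tokenized labels attached. Note that the
# keyword arguments of a combined call are forwarded to both components, so only pass
# kwargs that the feature extractor and the tokenizer both understand.
batch = processor(audio=speech_array, text="hello world", return_tensors="pt")
assert "labels" in batch
```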
@@ -57,8 +57,7 @@ class FlaxMT5Model(FlaxT5Model):
     >>> summary = "Weiter Verhandlung in Syrien."
     >>> inputs = tokenizer(article, return_tensors="np")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids
+    >>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids
     >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=decoder_input_ids)
     >>> hidden_states = outputs.last_hidden_state
@@ -84,8 +83,7 @@ class FlaxMT5EncoderModel(FlaxT5EncoderModel):
     >>> summary = "Weiter Verhandlung in Syrien."
     >>> inputs = tokenizer(article, return_tensors="np")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids
+    >>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids
     >>> outputs = model(input_ids=inputs["input_ids"])
     >>> hidden_states = outputs.last_hidden_state
@@ -111,8 +109,7 @@ class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration):
     >>> summary = "Weiter Verhandlung in Syrien."
     >>> inputs = tokenizer(article, return_tensors="np")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids
+    >>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids
     >>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
     >>> logits = outputs.logits
......
@@ -40,8 +40,7 @@ class MT5Model(T5Model):
     >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
     >>> summary = "Weiter Verhandlung in Syrien."
     >>> inputs = tokenizer(article, return_tensors="pt")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(summary, return_tensors="pt")
+    >>> labels = tokenizer(text_target=summary, return_tensors="pt")
     >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
     >>> hidden_states = outputs.last_hidden_state
@@ -73,11 +72,9 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
     >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
     >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
     >>> summary = "Weiter Verhandlung in Syrien."
-    >>> inputs = tokenizer(article, return_tensors="pt")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(summary, return_tensors="pt")
-    >>> outputs = model(**inputs, labels=labels["input_ids"])
+    >>> inputs = tokenizer(article, text_target=summary, return_tensors="pt")
+    >>> outputs = model(**inputs)
     >>> loss = outputs.loss
     ```"""
......
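When the source and target need different keyword arguments, the replacement for the context manager is simply two calls, as the deprecation message later in this diff spells out. A sketch with the same mt5 checkpoint as above:

```python
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
summary = "Weiter Verhandlung in Syrien."

# Different kwargs per side -> two separate calls; `text_target=` replaces the old
# `with tokenizer.as_target_tokenizer():` block.
model_inputs = tokenizer(article, max_length=128, truncation=True, return_tensors="pt")
labels = tokenizer(text_target=summary, max_length=32, truncation=True, return_tensors="pt")
model_inputs["labels"] = labels["input_ids"]
```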
@@ -40,8 +40,7 @@ class TFMT5Model(TFT5Model):
     >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
     >>> summary = "Weiter Verhandlung in Syrien."
     >>> inputs = tokenizer(article, return_tensors="tf")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(summary, return_tensors="tf")
+    >>> labels = tokenizer(text_target=summary, return_tensors="tf")
     >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
     >>> hidden_states = outputs.last_hidden_state
@@ -64,11 +63,9 @@ class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):
     >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
     >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
     >>> summary = "Weiter Verhandlung in Syrien."
-    >>> inputs = tokenizer(article, return_tensors="tf")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(summary, return_tensors="tf")
-    >>> outputs = model(**inputs, labels=labels["input_ids"])
+    >>> inputs = tokenizer(article, text_target=summary, return_tensors="tf")
+    >>> outputs = model(**inputs)
     >>> loss = outputs.loss
     ```"""
......
@@ -14,7 +14,6 @@
 # limitations under the License.
 import os
-from contextlib import contextmanager
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple
@@ -67,10 +66,7 @@ class NllbTokenizer(PreTrainedTokenizer):
     ... )
     >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
     >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
-    >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(expected_translation_french, return_tensors="pt")
-    >>> inputs["labels"] = labels["input_ids"]
+    >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
     ```
     Args:
@@ -386,15 +382,11 @@ class NllbTokenizer(PreTrainedTokenizer):
         self.tgt_lang = tgt_lang
         return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
-    @contextmanager
-    def as_target_tokenizer(self):
-        """
-        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
-        sequence-to-sequence models that need a slightly different processing for the labels.
-        """
-        self.set_tgt_lang_special_tokens(self.tgt_lang)
-        yield
-        self.set_src_lang_special_tokens(self.src_lang)
+    def _switch_to_input_mode(self):
+        return self.set_src_lang_special_tokens(self.src_lang)
+
+    def _switch_to_target_mode(self):
+        return self.set_tgt_lang_special_tokens(self.tgt_lang)
     def set_src_lang_special_tokens(self, src_lang) -> None:
         """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
......
@@ -14,7 +14,6 @@
 # limitations under the License.
 import os
-from contextlib import contextmanager
 from shutil import copyfile
 from typing import List, Optional, Tuple
@@ -80,10 +79,7 @@ class NllbTokenizerFast(PreTrainedTokenizerFast):
     ... )
     >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
     >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
-    >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(expected_translation_french, return_tensors="pt")
-    >>> inputs["labels"] = labels["input_ids"]
+    >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
     ```
     Args:
@@ -284,15 +280,11 @@ class NllbTokenizerFast(PreTrainedTokenizerFast):
         self.tgt_lang = tgt_lang
         return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
-    @contextmanager
-    def as_target_tokenizer(self):
-        """
-        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
-        sequence-to-sequence models that need a slightly different processing for the labels.
-        """
-        self.set_tgt_lang_special_tokens(self.tgt_lang)
-        yield
-        self.set_src_lang_special_tokens(self.src_lang)
+    def _switch_to_input_mode(self):
+        return self.set_src_lang_special_tokens(self.src_lang)
+
+    def _switch_to_target_mode(self):
+        return self.set_tgt_lang_special_tokens(self.tgt_lang)
     def set_src_lang_special_tokens(self, src_lang) -> None:
         """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
......
@@ -14,7 +14,6 @@
 # limitations under the License.
 import os
-from contextlib import contextmanager
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple
@@ -153,10 +152,7 @@ class PLBartTokenizer(PreTrainedTokenizer):
     >>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
     >>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
     >>> expected_translation_english = "Returns the maximum value of a b c."
-    >>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(expected_translation_english, return_tensors="pt")
-    >>> inputs["labels"] = labels["input_ids"]
+    >>> inputs = tokenizer(example_python_phrase, text_target=expected_translation_english, return_tensors="pt")
     ```"""
     vocab_files_names = VOCAB_FILES_NAMES
@@ -441,15 +437,11 @@ class PLBartTokenizer(PreTrainedTokenizer):
         self.tgt_lang = tgt_lang
         return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
-    @contextmanager
-    def as_target_tokenizer(self):
-        """
-        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
-        sequence-to-sequence models that need a slightly different processing for the labels.
-        """
-        self.set_tgt_lang_special_tokens(self.tgt_lang)
-        yield
-        self.set_src_lang_special_tokens(self.src_lang)
+    def _switch_to_input_mode(self):
+        return self.set_src_lang_special_tokens(self.src_lang)
+
+    def _switch_to_target_mode(self):
+        return self.set_tgt_lang_special_tokens(self.tgt_lang)
     def set_src_lang_special_tokens(self, src_lang) -> None:
         """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
......
@@ -818,8 +818,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
     >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
     >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
+    >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
     >>> input_ids = inputs["input_ids"]
     >>> labels = targets["input_ids"]
     >>> outputs = model(input_ids=input_ids, labels=labels)
@@ -1287,8 +1286,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
     >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
     >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
+    >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
     >>> input_ids = inputs["input_ids"]
     >>> labels = targets["input_ids"]
     >>> outputs = model(input_ids=input_ids, labels=labels)
......
@@ -15,7 +15,6 @@
 """Tokenization classes for RAG."""
 import os
 import warnings
-from contextlib import contextmanager
 from typing import List, Optional
 from ...tokenization_utils_base import BatchEncoding
@@ -68,16 +67,12 @@ class RagTokenizer:
     def decode(self, *args, **kwargs):
         return self.generator.decode(*args, **kwargs)
-    @contextmanager
-    def as_target_tokenizer(self):
-        """
-        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
-        sequence-to-sequence models that need a slightly different processing for the labels.
-        """
-        self.current_tokenizer = self.generator
-        yield
-        self.current_tokenizer = self.question_encoder
+    def _switch_to_input_mode(self):
+        self.current_tokenizer = self.question_encoder
+
+    def _switch_to_target_mode(self):
+        self.current_tokenizer = self.generator
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
@@ -110,11 +105,10 @@ class RagTokenizer:
         if tgt_texts is None:
             return model_inputs
         # Process tgt_texts
-        with self.as_target_tokenizer():
         if max_target_length is None:
             max_target_length = self.current_tokenizer.model_max_length
         labels = self(
-            tgt_texts,
+            text_target=tgt_texts,
             add_special_tokens=True,
             return_tensors=return_tensors,
             padding=padding,
......
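Putting the RAG pieces together, the updated docstrings above translate into the following usage sketch; the retriever configuration (`index_name="exact"`, `use_dummy_dataset=True`) is the usual docstring setup and is assumed to be downloadable in your environment:

```python
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
# The target sentence goes through `text_target` instead of `as_target_tokenizer()`.
targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")

outputs = model(input_ids=inputs["input_ids"], labels=targets["input_ids"])
```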
@@ -482,8 +482,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
     'Mr. Quilter ist der Apostel der Mittelschicht und wir freuen uns, sein Evangelium willkommen heißen zu können.'
     >>> # Training: Train model on English transcription
-    >>> with processor.as_target_processor():
-    ...     labels = processor(ds[0]["text"], return_tensors="pt").input_ids
+    >>> labels = processor(text=ds[0]["text"], return_tensors="pt").input_ids
     >>> loss = model(input_values, labels=labels).loss
     >>> loss.backward()
......
@@ -15,6 +15,7 @@
 """
 Speech processor class for Speech2Text
 """
+import warnings
 from contextlib import contextmanager
 from ...processing_utils import ProcessorMixin
@@ -41,6 +42,7 @@ class Speech2TextProcessor(ProcessorMixin):
     def __init__(self, feature_extractor, tokenizer):
         super().__init__(feature_extractor, tokenizer)
         self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
     def __call__(self, *args, **kwargs):
         """
@@ -50,8 +52,36 @@ class Speech2TextProcessor(ProcessorMixin):
         [`~Speech2TextTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more
         information.
         """
-        return self.current_processor(*args, **kwargs)
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor(*args, **kwargs)
+
+        if "raw_speech" in kwargs:
+            warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+            audio = kwargs.pop("raw_speech")
+        else:
+            audio = kwargs.pop("audio", None)
+        text = kwargs.pop("text", None)
+        if len(args) > 0:
+            audio = args[0]
+            args = args[1:]
+
+        if audio is None and text is None:
+            raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+        if audio is not None:
+            inputs = self.feature_extractor(audio, *args, **kwargs)
+        if text is not None:
+            encodings = self.tokenizer(text, **kwargs)
+
+        if text is None:
+            return inputs
+        elif audio is None:
+            return encodings
+        else:
+            inputs["labels"] = encodings["input_ids"]
+            return inputs
     def batch_decode(self, *args, **kwargs):
         """
         This method forwards all its arguments to Speech2TextTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
@@ -72,6 +102,13 @@ class Speech2TextProcessor(ProcessorMixin):
         Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
         Speech2Text.
         """
+        warnings.warn(
+            "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+            "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+            "your audio inputs, or in a separate call."
+        )
+        self._in_target_context_manager = True
         self.current_processor = self.tokenizer
         yield
         self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
@@ -15,6 +15,7 @@
 """
 Speech processor class for Speech2Text2
 """
+import warnings
 from contextlib import contextmanager
 from ...processing_utils import ProcessorMixin
@@ -40,6 +41,7 @@ class Speech2Text2Processor(ProcessorMixin):
     def __init__(self, feature_extractor, tokenizer):
         super().__init__(feature_extractor, tokenizer)
         self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
     def __call__(self, *args, **kwargs):
         """
@@ -49,8 +51,36 @@ class Speech2Text2Processor(ProcessorMixin):
         Speech2Text2Tokenizer's [`~Speech2Text2Tokenizer.__call__`]. Please refer to the doctsring of the above two
         methods for more information.
         """
-        return self.current_processor(*args, **kwargs)
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor(*args, **kwargs)
+
+        if "raw_speech" in kwargs:
+            warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+            audio = kwargs.pop("raw_speech")
+        else:
+            audio = kwargs.pop("audio", None)
+        text = kwargs.pop("text", None)
+        if len(args) > 0:
+            audio = args[0]
+            args = args[1:]
+
+        if audio is None and text is None:
+            raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+        if audio is not None:
+            inputs = self.feature_extractor(audio, *args, **kwargs)
+        if text is not None:
+            encodings = self.tokenizer(text, **kwargs)
+
+        if text is None:
+            return inputs
+        elif audio is None:
+            return encodings
+        else:
+            inputs["labels"] = encodings["input_ids"]
+            return inputs
     def batch_decode(self, *args, **kwargs):
         """
         This method forwards all its arguments to Speech2Text2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
@@ -71,6 +101,13 @@ class Speech2Text2Processor(ProcessorMixin):
         Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
         Speech2Text2.
         """
+        warnings.warn(
+            "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+            "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+            "your audio inputs, or in a separate call."
+        )
+        self._in_target_context_manager = True
         self.current_processor = self.tokenizer
         yield
         self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
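The `_in_target_context_manager` flag added to each processor keeps the deprecated context manager working: inside it, calls are still forwarded straight to the tokenizer, only a deprecation warning is emitted first. A sketch of that backward-compatibility path, assuming `processor` is an instantiated `Speech2Text2Processor`:

```python
import warnings

# `processor` is assumed to be an instantiated Speech2Text2Processor.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Deprecated path: still forwards to the tokenizer, exactly as before.
    with processor.as_target_processor():
        labels = processor("hello world", return_tensors="pt").input_ids

# The only visible difference is the deprecation warning emitted on entry.
assert any("as_target_processor" in str(w.message) for w in caught)
```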
@@ -17,7 +17,6 @@
 import json
 import os
 import random
-from contextlib import contextmanager
 from functools import lru_cache
 from typing import Dict, List, Optional, Tuple, Union
@@ -63,12 +62,6 @@ class TapexTruncationStrategy(ExplicitEnum):
     DROP_ROWS_TO_FIT = "drop_rows_to_fit"
-class TokenizerStrategy(ExplicitEnum):
-    TOKENIZE_SOURCE = "tokenize_source"
-    TOKENIZE_TARGET = "tokenize_target"
 TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
     add_special_tokens (`bool`, *optional*, defaults to `True`):
         Whether or not to encode the sequences with the special tokens relative to their model.
@@ -341,9 +334,6 @@ class TapexTokenizer(PreTrainedTokenizer):
         self.max_cell_length = max_cell_length
         self.table_linearize = IndexedRowTableLinearize()
-        # property to decide using which call function
-        self.current_tokenizer = TokenizerStrategy.TOKENIZE_SOURCE
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
@@ -555,9 +545,7 @@ class TapexTokenizer(PreTrainedTokenizer):
                 Optionally, the corresponding answer to the questions as supervision.
         """
-        if self.current_tokenizer == TokenizerStrategy.TOKENIZE_SOURCE:
-            if table is None:
-                raise ValueError("Please ensure that the table is not empty if you use TAPEX to encode source.")
+        if table is not None:
             return self.source_call_func(
                 table=table,
                 query=query,
@@ -578,9 +566,7 @@ class TapexTokenizer(PreTrainedTokenizer):
                 verbose=verbose,
                 **kwargs,
             )
-        else:
-            if answer is None:
-                raise ValueError("Please ensure that the answer is not empty if you use TAPEX to encode target.")
+        elif answer is not None:
             return self.target_call_func(
                 answer=answer,
                 add_special_tokens=add_special_tokens,
@@ -599,6 +585,8 @@ class TapexTokenizer(PreTrainedTokenizer):
                 verbose=verbose,
                 **kwargs,
             )
+        else:
+            raise ValueError("You need to provide either a `table` or an `answer`.")
     def source_call_func(
         self,
@@ -1330,17 +1318,6 @@ class TapexTokenizer(PreTrainedTokenizer):
             verbose=verbose,
         )
-    @contextmanager
-    def as_target_tokenizer(self):
-        """
-        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
-        sequence-to-sequence models that need a slightly different processing for the labels.
-        """
-        self.current_tokenizer = TokenizerStrategy.TOKENIZE_TARGET
-        yield
-        # restore the call function
-        self.current_tokenizer = TokenizerStrategy.TOKENIZE_SOURCE
     def prepare_table_query(
         self,
         table,
......
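For TAPEX the mode flag is gone entirely: the arguments themselves decide the path (a `table` means source encoding, a bare `answer` means target encoding, and passing neither raises a `ValueError`). A hedged sketch; the `microsoft/tapex-base` checkpoint name and the toy table are illustrative and not taken from this diff:

```python
import pandas as pd
from transformers import TapexTokenizer

tokenizer = TapexTokenizer.from_pretrained("microsoft/tapex-base")
table = pd.DataFrame({"year": ["1896", "2012"], "city": ["athens", "london"]})

# A table (plus optional query) -> source encoding, same as before.
encoding = tokenizer(table=table, query="in which year did london host the olympics?", return_tensors="pt")

# An answer alone -> target encoding; this used to require `with tokenizer.as_target_tokenizer():`.
labels = tokenizer(answer="2012", return_tensors="pt")

# Passing neither a table nor an answer now raises:
# ValueError: You need to provide either a `table` or an `answer`.
```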
@@ -15,6 +15,7 @@
 """
 Processor class for TrOCR.
 """
+import warnings
 from contextlib import contextmanager
 from ...processing_utils import ProcessorMixin
@@ -40,6 +41,7 @@ class TrOCRProcessor(ProcessorMixin):
     def __init__(self, feature_extractor, tokenizer):
         super().__init__(feature_extractor, tokenizer)
         self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
     def __call__(self, *args, **kwargs):
         """
@@ -48,8 +50,36 @@ class TrOCRProcessor(ProcessorMixin):
         [`~TrOCRProcessor.as_target_processor`] this method forwards all its arguments to TrOCRTokenizer's
         [`~TrOCRTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more information.
         """
-        return self.current_processor(*args, **kwargs)
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor(*args, **kwargs)
+
+        if "raw_speech" in kwargs:
+            warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+            audio = kwargs.pop("raw_speech")
+        else:
+            audio = kwargs.pop("audio", None)
+        text = kwargs.pop("text", None)
+        if len(args) > 0:
+            audio = args[0]
+            args = args[1:]
+
+        if audio is None and text is None:
+            raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+        if audio is not None:
+            inputs = self.feature_extractor(audio, *args, **kwargs)
+        if text is not None:
+            encodings = self.tokenizer(text, **kwargs)
+
+        if text is None:
+            return inputs
+        elif audio is None:
+            return encodings
+        else:
+            inputs["labels"] = encodings["input_ids"]
+            return inputs
     def batch_decode(self, *args, **kwargs):
         """
         This method forwards all its arguments to TrOCRTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer
@@ -69,6 +99,13 @@ class TrOCRProcessor(ProcessorMixin):
         """
         Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning TrOCR.
         """
+        warnings.warn(
+            "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+            "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+            "your audio inputs, or in a separate call."
+        )
+        self._in_target_context_manager = True
         self.current_processor = self.tokenizer
         yield
         self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
@@ -1650,9 +1650,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
     >>> # compute loss
     >>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
-    >>> # wrap processor as target processor to encode labels
-    >>> with processor.as_target_processor():
-    ...     labels = processor(transcription, return_tensors="tf").input_ids
+    >>> # Pass transcription as `text` to encode labels
+    >>> labels = processor(text=transcription, return_tensors="tf").input_ids
     >>> loss = model(input_values, labels=labels).loss
     ```"""
......
@@ -43,6 +43,7 @@ class Wav2Vec2Processor(ProcessorMixin):
     def __init__(self, feature_extractor, tokenizer):
         super().__init__(feature_extractor, tokenizer)
         self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
@@ -70,8 +71,36 @@ class Wav2Vec2Processor(ProcessorMixin):
         [`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
         [`~PreTrainedTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
         """
-        return self.current_processor(*args, **kwargs)
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor(*args, **kwargs)
+
+        if "raw_speech" in kwargs:
+            warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+            audio = kwargs.pop("raw_speech")
+        else:
+            audio = kwargs.pop("audio", None)
+        text = kwargs.pop("text", None)
+        if len(args) > 0:
+            audio = args[0]
+            args = args[1:]
+
+        if audio is None and text is None:
+            raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+        if audio is not None:
+            inputs = self.feature_extractor(audio, *args, **kwargs)
+        if text is not None:
+            encodings = self.tokenizer(text, **kwargs)
+
+        if text is None:
+            return inputs
+        elif audio is None:
+            return encodings
+        else:
+            inputs["labels"] = encodings["input_ids"]
+            return inputs
     def pad(self, *args, **kwargs):
         """
         When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
@@ -79,8 +108,29 @@ class Wav2Vec2Processor(ProcessorMixin):
         [`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
         [`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.
         """
-        return self.current_processor.pad(*args, **kwargs)
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor.pad(*args, **kwargs)
+
+        input_features = kwargs.pop("input_features", None)
+        labels = kwargs.pop("labels", None)
+        if len(args) > 0:
+            input_features = args[0]
+            args = args[1:]
+
+        if input_features is not None:
+            input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
+        if labels is not None:
+            labels = self.tokenizer.pad(labels, **kwargs)
+
+        if labels is None:
+            return input_features
+        elif input_features is None:
+            return labels
+        else:
+            input_features["labels"] = labels["input_ids"]
+            return input_features
     def batch_decode(self, *args, **kwargs):
         """
         This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
@@ -101,6 +151,13 @@ class Wav2Vec2Processor(ProcessorMixin):
         Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
         Wav2Vec2.
         """
+        warnings.warn(
+            "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+            "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+            "your audio inputs, or in a separate call."
+        )
+        self._in_target_context_manager = True
         self.current_processor = self.tokenizer
         yield
         self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
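The new `pad` dispatch is what a CTC data collator can now rely on: one `processor.pad(...)` call pads both the audio features and the label ids. A sketch of such a collator, assuming each feature dict was produced by the combined `__call__` above (so it carries `input_values` and `labels`):

```python
def collate_ctc_batch(features, processor):
    # Each feature dict is assumed to hold "input_values" (audio) and "labels" (token ids).
    input_features = [{"input_values": f["input_values"]} for f in features]
    label_features = [{"input_ids": f["labels"]} for f in features]

    # One call pads both sides; the padded label ids come back under "labels".
    batch = processor.pad(
        input_features=input_features,
        labels=label_features,
        padding=True,
        return_tensors="pt",
    )

    # Usual CTC trick: ignore padded label positions in the loss.
    batch["labels"] = batch["labels"].masked_fill(
        batch["labels"].eq(processor.tokenizer.pad_token_id), -100
    )
    return batch
```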
@@ -16,6 +16,7 @@
 Speech processor class for Wav2Vec2
 """
 import os
+import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
 from multiprocessing import get_context
@@ -99,6 +100,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
         self.decoder = decoder
         self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
     def save_pretrained(self, save_directory):
         super().save_pretrained(save_directory)
@@ -214,8 +216,36 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
         Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.__call__`]. Please refer to the docstring of the above two
         methods for more information.
         """
-        return self.current_processor(*args, **kwargs)
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor(*args, **kwargs)
+
+        if "raw_speech" in kwargs:
+            warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+            audio = kwargs.pop("raw_speech")
+        else:
+            audio = kwargs.pop("audio", None)
+        text = kwargs.pop("text", None)
+        if len(args) > 0:
+            audio = args[0]
+            args = args[1:]
+
+        if audio is None and text is None:
+            raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+        if audio is not None:
+            inputs = self.feature_extractor(audio, *args, **kwargs)
+        if text is not None:
+            encodings = self.tokenizer(text, **kwargs)
+
+        if text is None:
+            return inputs
+        elif audio is None:
+            return encodings
+        else:
+            inputs["labels"] = encodings["input_ids"]
+            return inputs
     def pad(self, *args, **kwargs):
         """
         When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
@@ -224,8 +254,29 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
         Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.pad`]. Please refer to the docstring of the above two methods
         for more information.
         """
-        return self.current_processor.pad(*args, **kwargs)
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor.pad(*args, **kwargs)
+
+        input_features = kwargs.pop("input_features", None)
+        labels = kwargs.pop("labels", None)
+        if len(args) > 0:
+            input_features = args[0]
+            args = args[1:]
+
+        if input_features is not None:
+            input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
+        if labels is not None:
+            labels = self.tokenizer.pad(labels, **kwargs)
+
+        if labels is None:
+            return input_features
+        elif input_features is None:
+            return labels
+        else:
+            input_features["labels"] = labels["input_ids"]
+            return input_features
     def batch_decode(
         self,
         logits: np.ndarray,
@@ -486,9 +537,16 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
     @contextmanager
     def as_target_processor(self):
         """
-        Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
+        Temporarily sets the processor for processing the target. Useful for encoding the labels when fine-tuning
         Wav2Vec2.
         """
+        warnings.warn(
+            "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+            "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+            "your audio inputs, or in a separate call."
+        )
+        self._in_target_context_manager = True
         self.current_processor = self.tokenizer
         yield
         self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
@@ -1501,7 +1501,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         self.deprecation_warnings = (
             {}
         )  # Use to store when we have already noticed a deprecation warning (avoid overlogging).
+        self._in_target_context_manager = False
         super().__init__(**kwargs)
     @property
@@ -2431,8 +2431,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
     @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
         text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        text_pair_target: Optional[
+            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
+        ] = None,
         add_special_tokens: bool = True,
         padding: Union[bool, str, PaddingStrategy] = False,
         truncation: Union[bool, str, TruncationStrategy] = False,
@@ -2455,15 +2459,85 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         sequences.
         Args:
-            text (`str`, `List[str]`, `List[List[str]]`):
+            text (`str`, `List[str]`, `List[List[str]]`, *optional*):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            text_pair (`str`, `List[str]`, `List[List[str]]`):
+            text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            text_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
+                The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
+                list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
+                you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
+                The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
+                list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
+                you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
         """
+        # To avoid duplicating
+        all_kwargs = dict(
+            add_special_tokens=add_special_tokens,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            stride=stride,
+            is_split_into_words=is_split_into_words,
+            pad_to_multiple_of=pad_to_multiple_of,
+            return_tensors=return_tensors,
+            return_token_type_ids=return_token_type_ids,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_offsets_mapping=return_offsets_mapping,
+            return_length=return_length,
+            verbose=verbose,
+        )
+        all_kwargs.update(kwargs)
+        if text is None and text_target is None:
+            raise ValueError("You need to specify either `text` or `text_target`.")
+        if text is not None:
+            # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the
+            # input mode in this case.
+            if not self._in_target_context_manager:
+                self._switch_to_input_mode()
+            encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
+        if text_target is not None:
+            self._switch_to_target_mode()
+            target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **all_kwargs)
+        # Leave back tokenizer in input mode
+        self._switch_to_input_mode()
+
+        if text_target is None:
+            return encodings
+        elif text is None:
+            return target_encodings
+        else:
+            encodings["labels"] = target_encodings["input_ids"]
+            return encodings
+
+    def _call_one(
+        self,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+        text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = False,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ) -> BatchEncoding:
         # Input type checking for clearer error
         def _is_valid_text_input(t):
             if isinstance(t, str):
@@ -3456,13 +3530,34 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         )
         self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
+    def _switch_to_input_mode(self):
+        """
+        Private method to put the tokenizer in input mode (when it has different modes for input/outputs)
+        """
+        pass
+
+    def _switch_to_target_mode(self):
+        """
+        Private method to put the tokenizer in target mode (when it has different modes for input/outputs)
+        """
+        pass
+
     @contextmanager
     def as_target_tokenizer(self):
         """
         Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
         sequence-to-sequence models that need a slightly different processing for the labels.
         """
+        warnings.warn(
+            "`as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your "
+            "labels by using the argument `text_target` of the regular `__call__` method (either in the same call as "
+            "your input texts if you use the same keyword arguments, or in a separate call."
+        )
+        self._switch_to_target_mode()
+        self._in_target_context_manager = True
         yield
+        self._in_target_context_manager = False
+        self._switch_to_input_mode()
     @classmethod
     def register_for_auto_class(cls, auto_class="AutoTokenizer"):
@@ -3563,14 +3658,17 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
     # docstyle-ignore
     formatted_warning = """
 `prepare_seq2seq_batch` is deprecated and will be removed in version 5 of HuggingFace Transformers. Use the regular
-`__call__` method to prepare your inputs and the tokenizer under the `as_target_tokenizer` context manager to prepare
-your targets.
+`__call__` method to prepare your inputs and targets.
 Here is a short example:
+model_inputs = tokenizer(src_texts, text_target=tgt_texts, ...)
+If you either need to use different keyword arguments for the source and target texts, you should do two calls like
+this:
 model_inputs = tokenizer(src_texts, ...)
-with tokenizer.as_target_tokenizer():
-    labels = tokenizer(tgt_texts, ...)
+labels = tokenizer(text_target=tgt_texts, ...)
 model_inputs["labels"] = labels["input_ids"]
 See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice.
......
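The mechanism behind all of the per-model changes lives in this base-class hunk: `__call__` accepts `text_target` / `text_pair_target`, wraps target encoding between `_switch_to_target_mode()` and `_switch_to_input_mode()`, and attaches the target `input_ids` as `labels`. A toy illustration of that hook pattern (deliberately not Transformers code):

```python
# Toy stand-in for PreTrainedTokenizerBase, showing how `__call__` drives the two hooks.
class ToySeq2SeqTokenizer:
    def __init__(self):
        self.mode = "input"

    def _switch_to_input_mode(self):
        self.mode = "input"

    def _switch_to_target_mode(self):
        self.mode = "target"

    def _encode(self, text):
        prefix = "<src>" if self.mode == "input" else "<tgt>"
        return [prefix] + text.split()

    def __call__(self, text=None, text_target=None):
        result = {}
        if text is not None:
            self._switch_to_input_mode()
            result["input_ids"] = self._encode(text)
        if text_target is not None:
            self._switch_to_target_mode()
            result["labels"] = self._encode(text_target)
            self._switch_to_input_mode()
        return result


tok = ToySeq2SeqTokenizer()
print(tok(text="hello world", text_target="bonjour le monde"))
# {'input_ids': ['<src>', 'hello', 'world'], 'labels': ['<tgt>', 'bonjour', 'le', 'monde']}
```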
@@ -428,8 +428,7 @@ PT_SPEECH_CTC_SAMPLE = r"""
     ```
     ```python
-    >>> with processor.as_target_processor():
-    ...     inputs["labels"] = processor(dataset[0]["text"], return_tensors="pt").input_ids
+    >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
     >>> # compute loss
     >>> loss = model(**inputs).loss
@@ -849,8 +848,7 @@ TF_SPEECH_CTC_SAMPLE = r"""
     ```
     ```python
-    >>> with processor.as_target_processor():
-    ...     inputs["labels"] = processor(dataset[0]["text"], return_tensors="tf").input_ids
+    >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids
    >>> # compute loss
     >>> loss = model(**inputs).loss
......
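The updated doc samples reduce CTC label encoding to a single line; spelled out as a full snippet, assuming `processor`, `model` and `dataset` are loaded as in those samples:

```python
# `processor`, `model` (a CTC model) and `dataset` are assumed to be loaded as in the samples above.
audio = dataset[0]["audio"]

inputs = processor(audio=audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids

loss = model(**inputs).loss
```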