Unverified Commit 986526a0 authored by Sylvain Gugger, committed by GitHub

Replace `as_target` context managers by direct calls (#18325)



* Preliminary work on tokenizers

* Quality + fix tests

* Treat processors

* Fix pad

* Remove all uses of  in tests, docs and examples

* Replace all as_target_tokenizer

* Fix tests

* Fix quality

* Update examples/flax/image-captioning/run_image_captioning_flax.py
Co-authored-by: amyeroberts <amy@huggingface.co>

* Style
Co-authored-by: amyeroberts <amy@huggingface.co>
parent a64bcb56
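The change is mechanical across all the files below: target text that used to be encoded inside a `with tokenizer.as_target_tokenizer():` (or `processor.as_target_processor()`) block is now passed directly through a keyword argument, `text_target=` for tokenizers and `text=` / `labels=` for processors. A minimal before/after sketch, assuming a generic seq2seq checkpoint such as `t5-small` (not one of the files touched in this commit):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
inputs = ["translate English to German: The house is wonderful."]
targets = ["Das Haus ist wunderbar."]

# Old pattern (removed by this PR): encode targets inside a context manager.
# with tokenizer.as_target_tokenizer():
#     labels = tokenizer(targets, max_length=128, truncation=True)["input_ids"]

# New pattern: pass the targets directly; the returned encoding carries "labels".
model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
print(sorted(model_inputs.keys()))  # ['attention_mask', 'input_ids', 'labels']
```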
@@ -305,13 +305,12 @@ class DataCollatorCTCWithPadding:
             return_tensors="pt",
         )
-        with self.processor.as_target_processor():
-            labels_batch = self.processor.pad(
-                label_features,
-                padding=self.padding,
-                pad_to_multiple_of=self.pad_to_multiple_of_labels,
-                return_tensors="pt",
-            )
+        labels_batch = self.processor.pad(
+            labels=label_features,
+            padding=self.padding,
+            pad_to_multiple_of=self.pad_to_multiple_of_labels,
+            return_tensors="pt",
+        )
         # replace padding with -100 to ignore loss correctly
         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
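The CTC data collators in this commit all follow the same pattern on the processor side: label features are padded by passing `labels=` to `processor.pad` instead of entering `as_target_processor()`. A condensed, hypothetical collator in the spirit of the ones touched here (class and field names are illustrative, not the exact file content):

```python
from dataclasses import dataclass
from typing import Dict, List, Union

import torch
from transformers import Wav2Vec2Processor


@dataclass
class PaddingCTCCollator:
    """Hypothetical condensed collator using the new `labels=` argument of `processor.pad`."""

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": f["input_values"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.pad(input_features, padding=self.padding, return_tensors="pt")
        # New API: no `as_target_processor()` context manager, just `labels=...`.
        labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors="pt")

        # Replace padding with -100 so these positions are ignored by the CTC loss.
        batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        return batch
```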
@@ -522,9 +522,8 @@ def main():
         inputs = [prefix + inp for inp in inputs]
         model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
-        # Setup the tokenizer for targets
-        with tokenizer.as_target_tokenizer():
-            labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+        # Tokenize targets with the `text_target` keyword argument
+        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
         # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
         # padding in the loss.
@@ -470,9 +470,8 @@ def main():
         inputs = [prefix + inp for inp in inputs]
         model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
-        # Setup the tokenizer for targets
-        with tokenizer.as_target_tokenizer():
-            labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+        # Tokenize targets with the `text_target` keyword argument
+        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
         # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
         # padding in the loss.
@@ -443,9 +443,8 @@ def main():
         inputs = [prefix + inp for inp in inputs]
         model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
-        # Setup the tokenizer for targets
-        with tokenizer.as_target_tokenizer():
-            labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+        # Tokenize targets with the `text_target` keyword argument
+        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
         # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
         # padding in the loss.
@@ -452,9 +452,8 @@ def main():
         inputs = [prefix + inp for inp in inputs]
         model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
-        # Setup the tokenizer for targets
-        with tokenizer.as_target_tokenizer():
-            labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+        # Tokenize targets with the `text_target` keyword argument
+        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
         # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
         # padding in the loss.
@@ -304,13 +304,12 @@ class DataCollatorCTCWithPadding:
             return_tensors="pt",
         )
-        with self.processor.as_target_processor():
-            labels_batch = self.processor.pad(
-                label_features,
-                padding=self.padding,
-                pad_to_multiple_of=self.pad_to_multiple_of_labels,
-                return_tensors="pt",
-            )
+        labels_batch = self.processor.pad(
+            labels=label_features,
+            padding=self.padding,
+            pad_to_multiple_of=self.pad_to_multiple_of_labels,
+            return_tensors="pt",
+        )
         # replace padding with -100 to ignore loss correctly
         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
@@ -301,13 +301,12 @@ class DataCollatorCTCWithPadding:
             return_tensors="pt",
         )
-        with self.processor.as_target_processor():
-            labels_batch = self.processor.pad(
-                label_features,
-                padding=self.padding,
-                pad_to_multiple_of=self.pad_to_multiple_of_labels,
-                return_tensors="pt",
-            )
+        labels_batch = self.processor.pad(
+            labels=label_features,
+            padding=self.padding,
+            pad_to_multiple_of=self.pad_to_multiple_of_labels,
+            return_tensors="pt",
+        )
         # replace padding with -100 to ignore loss correctly
         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
@@ -437,13 +437,12 @@ def main():
             table=tables, query=questions, max_length=data_args.max_source_length, padding=padding, truncation=True
         )
-        with tokenizer.as_target_tokenizer():
-            labels = tokenizer(
-                answer=[", ".join(answer) for answer in answers],
-                max_length=max_target_length,
-                padding=padding,
-                truncation=True,
-            )
+        labels = tokenizer(
+            answer=[", ".join(answer) for answer in answers],
+            max_length=max_target_length,
+            padding=padding,
+            truncation=True,
+        )
         # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
         # padding in the loss.
@@ -413,13 +413,12 @@ def main():
             table=tables, query=questions, max_length=data_args.max_source_length, padding=padding, truncation=True
         )
-        with tokenizer.as_target_tokenizer():
-            labels = tokenizer(
-                answer=[", ".join(answer) for answer in answers],
-                max_length=max_target_length,
-                padding=padding,
-                truncation=True,
-            )
+        labels = tokenizer(
+            answer=[", ".join(answer) for answer in answers],
+            max_length=max_target_length,
+            padding=padding,
+            truncation=True,
+        )
         # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
         # padding in the loss.
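These two hunks come from the TAPEX examples, whose tokenizer takes a table and a query; after this change the answers are encoded as labels by calling the tokenizer with the `answer=` argument directly. A hedged sketch of the new usage, assuming the `microsoft/tapex-base` checkpoint and a toy pandas table (neither appears in the diff itself):

```python
import pandas as pd
from transformers import TapexTokenizer

tokenizer = TapexTokenizer.from_pretrained("microsoft/tapex-base")
table = pd.DataFrame({"city": ["Paris", "Berlin"], "country": ["France", "Germany"]})

tables = [table, table]
questions = ["which city is in France?", "which city is in Germany?"]
answers = [["Paris"], ["Berlin"]]  # each example may have several answer strings

# Encode the (table, question) pairs as model inputs.
model_inputs = tokenizer(table=tables, query=questions, padding=True, truncation=True)

# Encode the answers directly as labels, with no `as_target_tokenizer()` block.
labels = tokenizer(
    answer=[", ".join(answer) for answer in answers],
    padding=True,
    truncation=True,
)
model_inputs["labels"] = labels["input_ids"]
```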
@@ -266,14 +266,13 @@ class DataCollatorCTCWithPadding:
             pad_to_multiple_of=self.pad_to_multiple_of,
             return_tensors="pt",
         )
-        with self.processor.as_target_processor():
-            labels_batch = self.processor.pad(
-                label_features,
-                padding=self.padding,
-                max_length=self.max_length_labels,
-                pad_to_multiple_of=self.pad_to_multiple_of_labels,
-                return_tensors="pt",
-            )
+        labels_batch = self.processor.pad(
+            labels=label_features,
+            padding=self.padding,
+            max_length=self.max_length_labels,
+            pad_to_multiple_of=self.pad_to_multiple_of_labels,
+            return_tensors="pt",
+        )
         # replace padding with -100 to ignore loss correctly
         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
@@ -419,9 +418,10 @@ def main():
             len(set(batch["sampling_rate"])) == 1
         ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
-        batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
-        with processor.as_target_processor():
-            batch["labels"] = processor(batch[data_args.target_text_column]).input_ids
+        processed_batch = processor(
+            audio=batch["speech"], text=batch[data_args.target_text_column], sampling_rate=batch["sampling_rate"][0]
+        )
+        batch.update(processed_batch)
         return batch

     train_dataset = train_dataset.map(
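In the speech-recognition preprocessing above, a single processor call now covers both modalities: `audio=` is routed to the feature extractor, `text=` to the tokenizer, and the returned dict carries the label ids alongside `input_values`. A minimal sketch, assuming a Wav2Vec2 processor (`facebook/wav2vec2-base-960h`) and one second of silent 16 kHz audio as a placeholder:

```python
import numpy as np
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
speech = np.zeros(16000, dtype=np.float32)  # placeholder waveform, 1 s at 16 kHz

# One call replaces the old `as_target_processor()` block: the feature extractor
# handles `audio`, the tokenizer handles `text`, and labels are attached to the output.
encoded = processor(audio=speech, text="HELLO WORLD", sampling_rate=16000)
print(sorted(encoded.keys()))  # ['input_values', 'labels']
```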
@@ -185,14 +185,13 @@ class DataCollatorCTCWithPadding:
             pad_to_multiple_of=self.pad_to_multiple_of,
             return_tensors="pt",
         )
-        with self.processor.as_target_processor():
-            labels_batch = self.processor.pad(
-                label_features,
-                padding=self.padding,
-                max_length=self.max_length_labels,
-                pad_to_multiple_of=self.pad_to_multiple_of_labels,
-                return_tensors="pt",
-            )
+        labels_batch = self.processor.pad(
+            labels=label_features,
+            padding=self.padding,
+            max_length=self.max_length_labels,
+            pad_to_multiple_of=self.pad_to_multiple_of_labels,
+            return_tensors="pt",
+        )
         # replace padding with -100 to ignore loss correctly
         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
@@ -414,10 +413,11 @@ def main():
         assert (
             len(set(batch["sampling_rate"])) == 1
         ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
-        batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
-        # Setup the processor for targets
-        with processor.as_target_processor():
-            batch["labels"] = processor(batch["target_text"]).input_ids
+        processed_batch = processor(
+            audio=batch["speech"], text=batch["target_text"], sampling_rate=batch["sampling_rate"][0]
+        )
+        batch.update(processed_batch)
         return batch

     train_dataset = train_dataset.map(
@@ -349,13 +349,12 @@ class SpeechDataCollatorWithPadding:
         if self.pad_labels:
             label_features = [{"input_ids": feature["labels"]} for feature in features]
-            with self.processor.as_target_processor():
-                labels_batch = self.processor.pad(
-                    label_features,
-                    padding=self.padding,
-                    pad_to_multiple_of=self.pad_to_multiple_of_labels,
-                    return_tensors="pt",
-                )
+            labels_batch = self.processor.pad(
+                labels=label_features,
+                padding=self.padding,
+                pad_to_multiple_of=self.pad_to_multiple_of_labels,
+                return_tensors="pt",
+            )
             # replace padding with -100 to ignore loss correctly
             labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
@@ -504,9 +504,8 @@ def main():
         inputs = [prefix + inp for inp in inputs]
         model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
-        # Setup the tokenizer for targets
-        with tokenizer.as_target_tokenizer():
-            labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+        # Tokenize targets with the `text_target` keyword argument
+        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
         # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
         # padding in the loss.
@@ -458,9 +458,8 @@ def main():
         inputs = [prefix + inp for inp in inputs]
         model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
-        # Setup the tokenizer for targets
-        with tokenizer.as_target_tokenizer():
-            labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
+        # Tokenize targets with the `text_target` keyword argument
+        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
         # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
         # padding in the loss.
@@ -1612,9 +1612,8 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
        >>> # compute loss
        >>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
-       >>> # wrap processor as target processor to encode labels
-       >>> with processor.as_target_processor():
-       ...     labels = processor(transcription, return_tensors="tf").input_values
+       >>> # Pass the transcription as text to encode labels
+       >>> labels = processor(text=transcription, return_tensors="tf").input_values
        >>> loss = model(input_values, labels=labels).loss
        ```"""
@@ -14,7 +14,6 @@
 """Tokenization classes for M2M100."""
 import json
 import os
-from contextlib import contextmanager
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -116,10 +115,8 @@ class M2M100Tokenizer(PreTrainedTokenizer):
     >>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="en", tgt_lang="ro")
     >>> src_text = " UN Chief Says There Is No Military Solution in Syria"
     >>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
-    >>> model_inputs = tokenizer(src_text, return_tensors="pt")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(tgt_text, return_tensors="pt").input_ids
-    >>> model(**model_inputs, labels=labels)  # should work
+    >>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
+    >>> model(**model_inputs)  # should work
     ```"""

     vocab_files_names = VOCAB_FILES_NAMES
@@ -346,16 +343,12 @@ class M2M100Tokenizer(PreTrainedTokenizer):
             inputs["forced_bos_token_id"] = tgt_lang_id
         return inputs

-    @contextmanager
-    def as_target_tokenizer(self):
-        """
-        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
-        sequence-to-sequence models that need a slightly different processing for the labels.
-        """
-        self.set_tgt_lang_special_tokens(self.tgt_lang)
-        yield
-        self.set_src_lang_special_tokens(self.src_lang)
+    def _switch_to_input_mode(self):
+        self.set_src_lang_special_tokens(self.src_lang)
+
+    def _switch_to_target_mode(self):
+        self.set_tgt_lang_special_tokens(self.tgt_lang)

     def set_src_lang_special_tokens(self, src_lang: str) -> None:
         """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
         lang_token = self.get_lang_token(src_lang)
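For the multilingual tokenizers in this commit (M2M100 above, Marian and the MBart variants below), the context manager is replaced by the `_switch_to_input_mode` / `_switch_to_target_mode` hooks, which the base tokenizer class calls around the encoding of `text_target`. A hedged sketch of what this looks like from the user side, reusing the checkpoint from the docstring above:

```python
from transformers import M2M100Tokenizer

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="en", tgt_lang="ro")
src_text = " UN Chief Says There Is No Military Solution in Syria"
tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"

# `__call__` switches to target mode before encoding `text_target` and back to
# input mode afterwards, so the labels carry the target-language special tokens
# without any explicit context manager.
model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
print(sorted(model_inputs.keys()))  # ['attention_mask', 'input_ids', 'labels']
```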
@@ -15,7 +15,6 @@ import json
 import os
 import re
 import warnings
-from contextlib import contextmanager
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -112,10 +111,7 @@ class MarianTokenizer(PreTrainedTokenizer):
     >>> tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
     >>> src_texts = ["I am a small frog.", "Tom asked his teacher for advice."]
     >>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."]  # optional
-    >>> inputs = tokenizer(src_texts, return_tensors="pt", padding=True)
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(tgt_texts, return_tensors="pt", padding=True)
-    >>> inputs["labels"] = labels["input_ids"]
+    >>> inputs = tokenizer(src_texts, text_target=tgt_texts, return_tensors="pt", padding=True)
     # keys [input_ids, attention_mask, labels].

     >>> outputs = model(**inputs)  # should work
@@ -281,18 +277,14 @@ class MarianTokenizer(PreTrainedTokenizer):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return token_ids_0 + token_ids_1 + [self.eos_token_id]

-    @contextmanager
-    def as_target_tokenizer(self):
-        """
-        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
-        sequence-to-sequence models that need a slightly different processing for the labels.
-        """
-        self.current_spm = self.spm_target
-        if self.separate_vocabs:
-            self.current_encoder = self.target_encoder
-        yield
-        self.current_spm = self.spm_source
-        self.current_encoder = self.encoder
+    def _switch_to_input_mode(self):
+        self.current_spm = self.spm_source
+        self.current_encoder = self.encoder
+
+    def _switch_to_target_mode(self):
+        self.current_spm = self.spm_target
+        if self.separate_vocabs:
+            self.current_encoder = self.target_encoder

     @property
     def vocab_size(self) -> int:
@@ -14,7 +14,6 @@
 # limitations under the License.
 import os
-from contextlib import contextmanager
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple
@@ -69,10 +68,7 @@ class MBartTokenizer(PreTrainedTokenizer):
     >>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO")
     >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
     >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
-    >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(expected_translation_romanian, return_tensors="pt")
-    >>> inputs["labels"] = labels["input_ids"]
+    >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_romanian, return_tensors="pt")
     ```"""

     vocab_files_names = VOCAB_FILES_NAMES
@@ -340,15 +336,11 @@ class MBartTokenizer(PreTrainedTokenizer):
         self.tgt_lang = tgt_lang
         return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

-    @contextmanager
-    def as_target_tokenizer(self):
-        """
-        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
-        sequence-to-sequence models that need a slightly different processing for the labels.
-        """
-        self.set_tgt_lang_special_tokens(self.tgt_lang)
-        yield
-        self.set_src_lang_special_tokens(self.src_lang)
+    def _switch_to_input_mode(self):
+        return self.set_src_lang_special_tokens(self.src_lang)
+
+    def _switch_to_target_mode(self):
+        return self.set_tgt_lang_special_tokens(self.tgt_lang)

     def set_src_lang_special_tokens(self, src_lang) -> None:
         """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
@@ -14,7 +14,6 @@
 # limitations under the License.
 import os
-from contextlib import contextmanager
 from shutil import copyfile
 from typing import List, Optional, Tuple
@@ -82,10 +81,7 @@ class MBartTokenizerFast(PreTrainedTokenizerFast):
     ... )
     >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
     >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
-    >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(expected_translation_romanian, return_tensors="pt")
-    >>> inputs["labels"] = labels["input_ids"]
+    >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_romanian, return_tensors="pt")
     ```"""

     vocab_files_names = VOCAB_FILES_NAMES
@@ -240,15 +236,11 @@ class MBartTokenizerFast(PreTrainedTokenizerFast):
         self.tgt_lang = tgt_lang
         return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

-    @contextmanager
-    def as_target_tokenizer(self):
-        """
-        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
-        sequence-to-sequence models that need a slightly different processing for the labels.
-        """
-        self.set_tgt_lang_special_tokens(self.tgt_lang)
-        yield
-        self.set_src_lang_special_tokens(self.src_lang)
+    def _switch_to_input_mode(self):
+        return self.set_src_lang_special_tokens(self.src_lang)
+
+    def _switch_to_target_mode(self):
+        return self.set_tgt_lang_special_tokens(self.tgt_lang)

     def set_src_lang_special_tokens(self, src_lang) -> None:
         """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
@@ -14,7 +14,6 @@
 # limitations under the License.
 import os
-from contextlib import contextmanager
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple
@@ -102,10 +101,8 @@ class MBart50Tokenizer(PreTrainedTokenizer):
     >>> tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
     >>> src_text = " UN Chief Says There Is No Military Solution in Syria"
     >>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
-    >>> model_inputs = tokenizer(src_text, return_tensors="pt")
-    >>> with tokenizer.as_target_tokenizer():
-    ...     labels = tokenizer(tgt_text, return_tensors="pt").input_ids
-    >>> # model(**model_inputs, labels=labels) should work
+    >>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
+    >>> # model(**model_inputs) should work
     ```"""

     vocab_files_names = VOCAB_FILES_NAMES
@@ -337,15 +334,11 @@ class MBart50Tokenizer(PreTrainedTokenizer):
         self.tgt_lang = tgt_lang
         return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

-    @contextmanager
-    def as_target_tokenizer(self):
-        """
-        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
-        sequence-to-sequence models that need a slightly different processing for the labels.
-        """
-        self.set_tgt_lang_special_tokens(self.tgt_lang)
-        yield
-        self.set_src_lang_special_tokens(self.src_lang)
+    def _switch_to_input_mode(self):
+        return self.set_src_lang_special_tokens(self.src_lang)
+
+    def _switch_to_target_mode(self):
+        return self.set_tgt_lang_special_tokens(self.tgt_lang)

     def set_src_lang_special_tokens(self, src_lang: str) -> None:
         """Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""