Unverified Commit 9336086a authored by Sam Shleifer, committed by GitHub

prepare_seq2seq_batch makes labels; decoder_input_ids are made later. (#6654)

* broken test

* batch parity

* tests pass

* boom boom

* boom boom

* split out bart tokenizer tests

* fix tests

* boom boom

* Fixed dataset bug

* Fix marian

* Undo extra

* Get marian working

* Fix t5 tok tests

* Test passing

* Cleanup

* better assert msg

* require torch

* Fix mbart tests

* undo extra decoder_attn_mask change

* Fix import

* pegasus tokenizer can ignore src_lang kwargs

* unused kwarg test cov

* boom boom

* add todo for pegasus issue

* cover one word translation edge case

* Cleanup

* doc
parent cb276b41
@@ -71,8 +71,8 @@ Summarization Tips:
 (It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
 **Update 2018-07-18**
-Datasets: `Seq2SeqDataset` should be used for all tokenizers without a `prepare_seq2seq_batch` method. For those who do (like Marian, MBart), `TranslationDataset` should be used.
+Datasets: `LegacySeq2SeqDataset` will be used for all tokenizers without a `prepare_seq2seq_batch` method. Otherwise, `Seq2SeqDataset` will be used.
-A new dataset is needed to support multilingual tasks.
+Future work/help wanted: A new dataset to support multilingual tasks.
 ### Command Line Options
@@ -106,7 +106,7 @@ The following command should work on a 16GB GPU:
     --train_batch_size=1 \
     --eval_batch_size=1 \
     --output_dir=xsum_results \
-    --num_train_epochs 1 \
+    --num_train_epochs 6 \
     --model_name_or_path facebook/bart-large
 ```
...
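The README note above boils down to a one-line check on the tokenizer. A minimal sketch of that selection logic, mirroring the `finetune.py` change in this commit (`pick_dataset_class` is an illustrative helper, not part of the PR; it assumes it runs from `examples/seq2seq` so `utils.py` is importable):

```python
# Minimal sketch of the dataset-class selection described above.
from transformers import AutoTokenizer
from utils import LegacySeq2SeqDataset, Seq2SeqDataset  # examples/seq2seq/utils.py


def pick_dataset_class(tokenizer):
    # Tokenizers that implement prepare_seq2seq_batch (BART, Marian, MBart, Pegasus, T5)
    # get the new Seq2SeqDataset; everything else falls back to LegacySeq2SeqDataset.
    return Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset


print(pick_dataset_class(AutoTokenizer.from_pretrained("facebook/bart-large")))  # Seq2SeqDataset
```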
import argparse import argparse
import gc import gc
import os import os
import warnings
from pathlib import Path from pathlib import Path
from typing import List from typing import List
...@@ -11,6 +12,7 @@ from torch.nn import functional as F ...@@ -11,6 +12,7 @@ from torch.nn import functional as F
from lightning_base import generic_train from lightning_base import generic_train
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
try: try:
...@@ -22,6 +24,7 @@ try: ...@@ -22,6 +24,7 @@ try:
assert_all_frozen, assert_all_frozen,
calculate_bleu, calculate_bleu,
freeze_params, freeze_params,
label_smoothed_nll_loss,
pickle_load, pickle_load,
use_task_specific_params, use_task_specific_params,
) )
...@@ -34,12 +37,15 @@ except ImportError: ...@@ -34,12 +37,15 @@ except ImportError:
assert_all_frozen, assert_all_frozen,
calculate_bleu, calculate_bleu,
freeze_params, freeze_params,
label_smoothed_nll_loss,
pickle_load, pickle_load,
use_task_specific_params, use_task_specific_params,
) )
class BartSummarizationDistiller(SummarizationModule): class BartSummarizationDistiller(SummarizationModule):
"""Supports Bart, Pegasus and other models that inherit from Bart."""
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"] loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
def __init__(self, hparams): def __init__(self, hparams):
@@ -160,22 +166,32 @@ class BartSummarizationDistiller(SummarizationModule):
     def _step(self, batch):
         # assert is_frozen(self.teacher)
         pad_token_id = self.tokenizer.pad_token_id
-        input_ids, src_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
-        decoder_input_ids = y[:, :-1].contiguous()
-        labels = y[:, 1:].clone()
-        labels[y[:, 1:] == pad_token_id] = -100
+        input_ids, src_mask, tgt_ids = batch["input_ids"], batch["attention_mask"], batch["labels"]
+        decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
         # noinspection PyCallingNonCallable
-        sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
+        lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
             input_ids,
             attention_mask=src_mask,
             decoder_input_ids=decoder_input_ids,
-            labels=labels,
             output_hidden_states=True,
             output_attentions=False,
-        )
+            use_cache=False,
+        )  # TODO(@sshleifer): return_dict=True cleanup
+        # Same cross entropy vs. label smoothing logic as finetune.py
+        assert lm_logits.shape[-1] == self.model.config.vocab_size
+        if self.hparams.label_smoothing == 0:
+            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
+            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
+        else:
+            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
+            student_lm_loss, _ = label_smoothed_nll_loss(
+                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
+            )

         def zero_tensor():
-            return torch.tensor(0.0).type_as(sloss)
+            return torch.tensor(0.0).type_as(student_lm_loss)
         loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
         if self.different_encoder:
@@ -199,29 +215,26 @@ class BartSummarizationDistiller(SummarizationModule):
                 attention_mask=src_mask,
                 encoder_outputs=teacher_enc_outputs,
                 decoder_input_ids=decoder_input_ids,
-                lm_labels=labels,
+                lm_labels=tgt_ids,
                 output_hidden_states=True,
             )
         dec_mask = decoder_input_ids.ne(pad_token_id)
-        loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
+        loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
         if self.alpha_hid > 0:
             hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
         blended_loss = (
             self.alpha_ce * loss_ce
-            + self.alpha_mlm * sloss
+            + self.alpha_mlm * student_lm_loss
             + self.hparams.alpha_encoder_loss * loss_encoder
             + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
         )
-        return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
+        return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec

     def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches):
-        assert not isinstance(
-            hidden_states, torch.Tensor
-        ), f"expected list or tuple for hidden_states, got tensor of shape {hidden_states.shape}"
-        assert not isinstance(
-            hidden_states_T, torch.Tensor
-        ), f"expected list or tuple for hidden_states_T, got tensor of shape {hidden_states_T.shape}"
+        msg = "expected list or tuple for hidden_states, got tensor of shape: "
+        assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
+        assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
         mask = attention_mask.to(hidden_states[0])
         valid_count = mask.sum() * hidden_states[0].size(-1)
         hidden_losses = [
...@@ -233,7 +246,7 @@ class BartSummarizationDistiller(SummarizationModule): ...@@ -233,7 +246,7 @@ class BartSummarizationDistiller(SummarizationModule):
def add_distill_args(parser): def add_distill_args(parser):
-    parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
+    parser.add_argument("--teacher", type=str)
parser.add_argument("--alpha_ce", default=0.8, type=float) parser.add_argument("--alpha_ce", default=0.8, type=float)
parser.add_argument("--alpha_mlm", default=0.2, type=float) parser.add_argument("--alpha_mlm", default=0.2, type=float)
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float) parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
...@@ -245,8 +258,9 @@ def add_distill_args(parser): ...@@ -245,8 +258,9 @@ def add_distill_args(parser):
class BartTranslationDistiller(BartSummarizationDistiller): class BartTranslationDistiller(BartSummarizationDistiller):
"""Supports Mbart, Marian, other models that inherit from Bart."""
mode = "translation" mode = "translation"
loss_names = ["loss"]
metric_names = ["bleu"] metric_names = ["bleu"]
val_metric = "bleu" val_metric = "bleu"
...@@ -368,7 +382,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller): ...@@ -368,7 +382,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
attention_mask=source_mask, attention_mask=source_mask,
encoder_outputs=teacher_enc_outputs, encoder_outputs=teacher_enc_outputs,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
-            lm_labels=labels,
+            labels=labels,
output_hidden_states=True, output_hidden_states=True,
use_cache=False, use_cache=False,
) )
...@@ -402,6 +416,7 @@ def create_module(args): ...@@ -402,6 +416,7 @@ def create_module(args):
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None): def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
# TODO(SS): DELETE?
exp_dir = ckpt_path.parent exp_dir = ckpt_path.parent
if dest_dir is None: if dest_dir is None:
dest_dir = exp_dir dest_dir = exp_dir
...@@ -424,21 +439,19 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None): ...@@ -424,21 +439,19 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
trainer.test(model) trainer.test(model)
-def get_layers_to_copy(n_to_get, tot):
-    all_layers = list(range(tot))
-    if tot == 12:  # Alternating for special cases
-        layers_to_copy = {  # maps num layers in student -> which teacher layers to copy
+LAYERS_TO_COPY = {
+    # maps num layers in student -> which teacher layers to copy.
+    # 12: bart, 16: pegasus, 6: marian/Helsinki-NLP
+    12: {
         1: [0],
         2: [0, 6],
         3: [0, 6, 11],
         4: [0, 4, 8, 11],
         6: [0, 2, 4, 7, 9, 11],
         9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
-        12: all_layers,
-    }
-    return layers_to_copy[n_to_get]
-    elif tot == 16:
-        layers_to_copy = {  # maps num layers in student -> which teacher layers to copy
+        12: list(range(12)),
+    },
+    16: {  # maps num layers in student -> which teacher layers to copy
         1: [0],
         2: [0, 8],
         3: [0, 8, 15],
@@ -446,11 +459,20 @@ def get_layers_to_copy(n_to_get, tot):
         6: [0, 3, 6, 9, 12, 15],
         8: [0, 2, 4, 6, 8, 10, 12, 15],
         9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
-        16: all_layers,
-    }
-    return layers_to_copy[n_to_get]
-    else:
-        return all_layers[:n_to_get]
+        16: list(range(16)),
+    },
+    6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
+}
+# TODO: better version on theseus-bart branch
+
+
+def get_layers_to_copy(n_student, n_teacher):
+    try:
+        return LAYERS_TO_COPY[n_teacher][n_student]
+    except KeyError:
+        warnings.warn(
+            f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
+        )
+        return list(range(n_student))
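A quick sanity check of the refactored mapping above, written as a usage sketch (it assumes `get_layers_to_copy` is importable from `examples/seq2seq/distillation.py`):

```python
from distillation import get_layers_to_copy

assert get_layers_to_copy(6, 12) == [0, 2, 4, 7, 9, 11]  # 6-layer student from a 12-layer BART teacher
assert get_layers_to_copy(3, 16) == [0, 8, 15]           # 3-layer student from a 16-layer Pegasus teacher
assert get_layers_to_copy(2, 6) == [0, 5]                # 2-layer student from a 6-layer Marian teacher
# Unlisted teacher depths warn and fall back to the first n_student layers:
assert get_layers_to_copy(2, 24) == [0, 1]
```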
def distill_main(args): def distill_main(args):
......
...@@ -13,15 +13,16 @@ import torch ...@@ -13,15 +13,16 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train from lightning_base import BaseTransformer, add_generic_args, generic_train
-from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration
+from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
try: try:
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from .utils import ( from .utils import (
ROUGE_KEYS, ROUGE_KEYS,
LegacySeq2SeqDataset,
Seq2SeqDataset, Seq2SeqDataset,
TranslationDataset,
assert_all_frozen, assert_all_frozen,
calculate_bleu, calculate_bleu,
calculate_rouge, calculate_rouge,
...@@ -39,8 +40,8 @@ except ImportError: ...@@ -39,8 +40,8 @@ except ImportError:
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from utils import ( from utils import (
ROUGE_KEYS, ROUGE_KEYS,
LegacySeq2SeqDataset,
Seq2SeqDataset, Seq2SeqDataset,
TranslationDataset,
assert_all_frozen, assert_all_frozen,
calculate_bleu, calculate_bleu,
calculate_rouge, calculate_rouge,
...@@ -102,14 +103,13 @@ class SummarizationModule(BaseTransformer): ...@@ -102,14 +103,13 @@ class SummarizationModule(BaseTransformer):
self.hparams.git_sha = get_git_info()["repo_sha"] self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers self.num_workers = hparams.num_workers
-        self.decoder_start_token_id = None
+        self.decoder_start_token_id = None  # default to config
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer): if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang] self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
self.model.config.decoder_start_token_id = self.decoder_start_token_id self.model.config.decoder_start_token_id = self.decoder_start_token_id
-        if isinstance(self.tokenizer, MBartTokenizer) or isinstance(self.tokenizer, MarianTokenizer):
-            self.dataset_class = TranslationDataset
-        else:
-            self.dataset_class = Seq2SeqDataset
+        self.dataset_class = (
+            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
+        )
def freeze_embeds(self): def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -134,27 +134,25 @@ class SummarizationModule(BaseTransformer):
     def _step(self, batch: dict) -> Tuple:
         pad_token_id = self.tokenizer.pad_token_id
-        source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
+        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
+        tgt_ids = batch["labels"]
         if isinstance(self.model, T5ForConditionalGeneration):
-            decoder_input_ids = self.model._shift_right(target_ids)
-            lm_labels = target_ids
+            decoder_input_ids = self.model._shift_right(tgt_ids)
         else:
-            decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
-            lm_labels = target_ids[:, 1:].clone()  # why clone?
-        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
+            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

+        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
+        lm_logits = outputs[0]
         if self.hparams.label_smoothing == 0:
-            # Same behavior as modeling_bart.py
+            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
             loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
-            lm_logits = outputs[0]
             assert lm_logits.shape[-1] == self.model.config.vocab_size
-            loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
+            loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
         else:
-            lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
+            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
             loss, nll_loss = label_smoothed_nll_loss(
-                lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
+                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
             )
         return (loss,)
...@@ -167,7 +165,7 @@ class SummarizationModule(BaseTransformer): ...@@ -167,7 +165,7 @@ class SummarizationModule(BaseTransformer):
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
# tokens per batch # tokens per batch
-        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["decoder_input_ids"].ne(self.pad).sum()
+        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
return {"loss": loss_tensors[0], "log": logs} return {"loss": loss_tensors[0], "log": logs}
def validation_step(self, batch, batch_idx) -> Dict: def validation_step(self, batch, batch_idx) -> Dict:
...@@ -204,7 +202,7 @@ class SummarizationModule(BaseTransformer): ...@@ -204,7 +202,7 @@ class SummarizationModule(BaseTransformer):
) )
gen_time = (time.time() - t0) / batch["input_ids"].shape[0] gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
preds: List[str] = self.ids_to_clean_text(generated_ids) preds: List[str] = self.ids_to_clean_text(generated_ids)
-        target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
+        target: List[str] = self.ids_to_clean_text(batch["labels"])
loss_tensors = self._step(batch) loss_tensors = self._step(batch)
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
rouge: Dict = self.calc_generative_metrics(preds, target) rouge: Dict = self.calc_generative_metrics(preds, target)
......
...@@ -132,4 +132,6 @@ def run_generate(): ...@@ -132,4 +132,6 @@ def run_generate():
if __name__ == "__main__": if __name__ == "__main__":
# Usage for MT:
# python run_eval.py MODEL_NAME $DATA_DIR/test.source $save_dir/test_translations.txt --reference_path $DATA_DIR/test.target --score_path $save_dir/test_bleu.json --task translation $@
run_generate() run_generate()
...@@ -10,18 +10,18 @@ from unittest.mock import patch ...@@ -10,18 +10,18 @@ from unittest.mock import patch
import pytest import pytest
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from pytest import param
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import lightning_base import lightning_base
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.modeling_bart import shift_tokens_right
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu
from .distillation import distill_main, evaluate_checkpoint from .distillation import distill_main, evaluate_checkpoint
from .finetune import SummarizationModule, main from .finetune import SummarizationModule, main
from .pack_dataset import pack_data_dir from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate from .run_eval import generate_summaries_or_translations, run_generate
-from .utils import Seq2SeqDataset, TranslationDataset, label_smoothed_nll_loss, lmap, load_json
+from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
...@@ -452,18 +452,27 @@ def test_pack_dataset(): ...@@ -452,18 +452,27 @@ def test_pack_dataset():
assert orig_paths == new_paths assert orig_paths == new_paths
-@pytest.mark.parametrize(["tok_name"], [pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)])
-def test_mbart_dataset_truncation(tok_name):
+@pytest.mark.parametrize(
+    ["tok_name"],
+    [
+        pytest.param(MBART_TINY),
+        pytest.param(MARIAN_TINY),
+        pytest.param(T5_TINY),
+        pytest.param(BART_TINY),
+        pytest.param("google/pegasus-xsum"),
+    ],
+)
+def test_seq2seq_dataset_truncation(tok_name):
tokenizer = AutoTokenizer.from_pretrained(tok_name) tokenizer = AutoTokenizer.from_pretrained(tok_name)
tmp_dir = make_test_data_dir() tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES) max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES) max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
max_src_len = 4 max_src_len = 4
max_tgt_len = 8 max_tgt_len = 8
-    assert max_len_target > max_src_len  # Truncated
-    assert max_len_source > max_src_len
-    src_lang, tgt_lang = "ro_RO", "de_DE"  # NOT WHAT IT WAS TRAINED ON
-    train_dataset = TranslationDataset(
+    assert max_len_target > max_src_len  # Will be truncated
+    assert max_len_source > max_src_len  # Will be truncated
+    src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
+    train_dataset = Seq2SeqDataset(
tokenizer, tokenizer,
data_dir=tmp_dir, data_dir=tmp_dir,
type_path="train", type_path="train",
...@@ -479,10 +488,11 @@ def test_mbart_dataset_truncation(tok_name): ...@@ -479,10 +488,11 @@ def test_mbart_dataset_truncation(tok_name):
# show that articles were trimmed. # show that articles were trimmed.
assert batch["input_ids"].shape[1] == max_src_len assert batch["input_ids"].shape[1] == max_src_len
# show that targets are the same len # show that targets are the same len
-        assert batch["decoder_input_ids"].shape[1] == max_tgt_len
-        if tok_name == MARIAN_TINY:
+        assert batch["labels"].shape[1] == max_tgt_len
+        if tok_name != MBART_TINY:
             continue
# check language codes in correct place # check language codes in correct place
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang] assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
...@@ -491,14 +501,14 @@ def test_mbart_dataset_truncation(tok_name): ...@@ -491,14 +501,14 @@ def test_mbart_dataset_truncation(tok_name):
break # No need to test every batch break # No need to test every batch
-@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
-def test_summarization_dataset_truncation(tok):
+@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")])
+def test_legacy_dataset_truncation(tok):
     tokenizer = AutoTokenizer.from_pretrained(tok)
     tmp_dir = make_test_data_dir()
     max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
     max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
     trunc_target = 4
-    train_dataset = Seq2SeqDataset(
+    train_dataset = LegacySeq2SeqDataset(
tokenizer, tokenizer,
data_dir=tmp_dir, data_dir=tmp_dir,
type_path="train", type_path="train",
...@@ -512,6 +522,6 @@ def test_summarization_dataset_truncation(tok): ...@@ -512,6 +522,6 @@ def test_summarization_dataset_truncation(tok):
assert batch["input_ids"].shape[1] == max_len_source assert batch["input_ids"].shape[1] == max_len_source
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
# show that targets were truncated # show that targets were truncated
-    assert batch["decoder_input_ids"].shape[1] == trunc_target  # Truncated
+    assert batch["labels"].shape[1] == trunc_target  # Truncated
assert max_len_target > trunc_target # Truncated assert max_len_target > trunc_target # Truncated
break # No need to test every batch break # No need to test every batch
...@@ -3,7 +3,6 @@ import json ...@@ -3,7 +3,6 @@ import json
import linecache import linecache
import os import os
import pickle import pickle
import warnings
from logging import getLogger from logging import getLogger
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, Iterable, List from typing import Callable, Dict, Iterable, List
...@@ -41,6 +40,7 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): ...@@ -41,6 +40,7 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
"""Only used by LegacyDataset"""
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
return tokenizer( return tokenizer(
[line], [line],
...@@ -75,7 +75,7 @@ def trim_batch( ...@@ -75,7 +75,7 @@ def trim_batch(
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
-class Seq2SeqDataset(Dataset):
+class AbstractSeq2SeqDataset(Dataset):
def __init__( def __init__(
self, self,
tokenizer, tokenizer,
...@@ -102,11 +102,28 @@ class Seq2SeqDataset(Dataset): ...@@ -102,11 +102,28 @@ class Seq2SeqDataset(Dataset):
self.pad_token_id = self.tokenizer.pad_token_id self.pad_token_id = self.tokenizer.pad_token_id
self.src_lang = src_lang self.src_lang = src_lang
self.tgt_lang = tgt_lang self.tgt_lang = tgt_lang
self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)
def __len__(self): def __len__(self):
return len(self.src_lens) return len(self.src_lens)
@staticmethod
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
def make_sortish_sampler(self, batch_size):
return SortishSampler(self.src_lens, batch_size)
def __getitem__(self, item):
raise NotImplementedError("You must implement this")
def collate_fn(self, batch):
raise NotImplementedError("You must implement this")
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
def __getitem__(self, index) -> Dict[str, torch.Tensor]: def __getitem__(self, index) -> Dict[str, torch.Tensor]:
"""Call tokenizer on src and tgt_lines"""
index = index + 1 # linecache starts at 1 index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
...@@ -121,42 +138,27 @@ class Seq2SeqDataset(Dataset): ...@@ -121,42 +138,27 @@ class Seq2SeqDataset(Dataset):
         return {
             "input_ids": source_ids,
             "attention_mask": src_mask,
-            "decoder_input_ids": target_ids,
+            "labels": target_ids,
         }

-    @staticmethod
-    def get_char_lens(data_file):
-        return [len(x) for x in Path(data_file).open().readlines()]
-
     def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
         input_ids = torch.stack([x["input_ids"] for x in batch])
         masks = torch.stack([x["attention_mask"] for x in batch])
-        target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
+        target_ids = torch.stack([x["labels"] for x in batch])
         pad_token_id = self.pad_token_id
         y = trim_batch(target_ids, pad_token_id)
         source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
         batch = {
             "input_ids": source_ids,
             "attention_mask": source_mask,
-            "decoder_input_ids": y,
+            "labels": y,
         }
         return batch

-    def make_sortish_sampler(self, batch_size):
-        return SortishSampler(self.src_lens, batch_size)

-class TranslationDataset(Seq2SeqDataset):
+class Seq2SeqDataset(AbstractSeq2SeqDataset):
     """A dataset that calls prepare_seq2seq_batch."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.max_source_length != self.max_target_length:
warnings.warn(
f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. "
f"Imbalanced sequence lengths may be undesired for translation tasks"
)
def __getitem__(self, index) -> Dict[str, str]: def __getitem__(self, index) -> Dict[str, str]:
index = index + 1 # linecache starts at 1 index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
...@@ -169,6 +171,7 @@ class TranslationDataset(Seq2SeqDataset): ...@@ -169,6 +171,7 @@ class TranslationDataset(Seq2SeqDataset):
} }
def collate_fn(self, batch) -> Dict[str, torch.Tensor]: def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
"""Call prepare_seq2seq_batch."""
batch_encoding = self.tokenizer.prepare_seq2seq_batch( batch_encoding = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch], [x["src_texts"] for x in batch],
src_lang=self.src_lang, src_lang=self.src_lang,
...@@ -176,6 +179,8 @@ class TranslationDataset(Seq2SeqDataset): ...@@ -176,6 +179,8 @@ class TranslationDataset(Seq2SeqDataset):
tgt_lang=self.tgt_lang, tgt_lang=self.tgt_lang,
max_length=self.max_source_length, max_length=self.max_source_length,
max_target_length=self.max_target_length, max_target_length=self.max_target_length,
return_tensors="pt",
add_prefix_space=self.add_prefix_space,
) )
return batch_encoding.data return batch_encoding.data
...@@ -276,7 +281,11 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer ...@@ -276,7 +281,11 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()} return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
# Utilities for freezing parameters and checking whether they are frozen
def freeze_params(model: nn.Module): def freeze_params(model: nn.Module):
"""Set requires_grad=False for each of model.parameters()"""
for par in model.parameters(): for par in model.parameters():
par.requires_grad = False par.requires_grad = False
......
...@@ -151,6 +151,9 @@ def _prepare_bart_decoder_inputs( ...@@ -151,6 +151,9 @@ def _prepare_bart_decoder_inputs(
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id) decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
else: else:
decoder_padding_mask = invert_mask(decoder_padding_mask) decoder_padding_mask = invert_mask(decoder_padding_mask)
if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1:
# never mask leading token, even if it is pad
decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1]
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to( causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
dtype=causal_mask_dtype, device=decoder_input_ids.device dtype=causal_mask_dtype, device=decoder_input_ids.device
) )
......
...@@ -636,7 +636,7 @@ class T5PreTrainedModel(PreTrainedModel): ...@@ -636,7 +636,7 @@ class T5PreTrainedModel(PreTrainedModel):
# replace possible -100 values in labels by `pad_token_id` # replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
-        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `labels` has only positive values and -100"
+        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
return shifted_input_ids return shifted_input_ids
......
...@@ -33,6 +33,7 @@ _all_bart_models = [ ...@@ -33,6 +33,7 @@ _all_bart_models = [
"facebook/bart-large-cnn", "facebook/bart-large-cnn",
"facebook/bart-large-xsum", "facebook/bart-large-xsum",
"yjernite/bart_eli5", "yjernite/bart_eli5",
# This is not exhaustive: see https://huggingface.co/models?filter=bart
] ]
...@@ -117,6 +118,8 @@ class BartTokenizer(RobertaTokenizer): ...@@ -117,6 +118,8 @@ class BartTokenizer(RobertaTokenizer):
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
""" """
kwargs.pop("src_lang", None)
kwargs.pop("tgt_lang", None)
if max_length is None: if max_length is None:
max_length = self.model_max_length max_length = self.model_max_length
model_inputs: BatchEncoding = self( model_inputs: BatchEncoding = self(
...@@ -133,7 +136,7 @@ class BartTokenizer(RobertaTokenizer): ...@@ -133,7 +136,7 @@ class BartTokenizer(RobertaTokenizer):
# Process tgt_texts # Process tgt_texts
if max_target_length is None: if max_target_length is None:
max_target_length = max_length max_target_length = max_length
-        decoder_inputs: BatchEncoding = self(
+        labels = self(
             tgt_texts,
             add_special_tokens=True,
             return_tensors=return_tensors,
@@ -141,10 +144,8 @@ class BartTokenizer(RobertaTokenizer):
             max_length=max_target_length,
             truncation=truncation,
             **kwargs,
-        )
-        for k, v in decoder_inputs.items():
-            model_inputs[f"decoder_{k}"] = v
+        )["input_ids"]
+        model_inputs["labels"] = labels
         return model_inputs
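With this change the target side of the batch comes back as `labels`, and `decoder_input_ids` are built later in the training loop. A short usage sketch (the key names mirror the tests added below in this commit):

```python
from transformers import BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-large")
batch = tok.prepare_seq2seq_batch(
    ["A long paragraph for summrization."],
    tgt_texts=["Summary of the text."],
    return_tensors="pt",
)
assert "labels" in batch and "decoder_input_ids" not in batch
# decoder_input_ids are derived later, e.g. shift_tokens_right(batch["labels"], tok.pad_token_id)
```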
...@@ -245,7 +246,7 @@ class BartTokenizerFast(RobertaTokenizerFast): ...@@ -245,7 +246,7 @@ class BartTokenizerFast(RobertaTokenizerFast):
# Process tgt_texts # Process tgt_texts
if max_target_length is None: if max_target_length is None:
max_target_length = max_length max_target_length = max_length
-        decoder_inputs: BatchEncoding = self(
+        labels = self(
             tgt_texts,
             add_special_tokens=True,
             return_tensors=return_tensors,
@@ -253,8 +254,6 @@ class BartTokenizerFast(RobertaTokenizerFast):
             max_length=max_target_length,
             truncation=truncation,
             **kwargs,
-        )
-        for k, v in decoder_inputs.items():
-            model_inputs[f"decoder_{k}"] = v
+        )["input_ids"]
+        model_inputs["labels"] = labels
         return model_inputs
...@@ -160,9 +160,7 @@ class MarianTokenizer(PreTrainedTokenizer): ...@@ -160,9 +160,7 @@ class MarianTokenizer(PreTrainedTokenizer):
tokenizer_kwargs["max_length"] = max_target_length tokenizer_kwargs["max_length"] = max_target_length
self.current_spm = self.spm_target self.current_spm = self.spm_target
-        decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
-        for k, v in decoder_inputs.items():
-            model_inputs[f"decoder_{k}"] = v
+        model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
self.current_spm = self.spm_source self.current_spm = self.spm_source
return model_inputs return model_inputs
......
...@@ -98,32 +98,6 @@ class MBartTokenizer(XLMRobertaTokenizer): ...@@ -98,32 +98,6 @@ class MBartTokenizer(XLMRobertaTokenizer):
self._additional_special_tokens = list(self.lang_code_to_id.keys()) self._additional_special_tokens = list(self.lang_code_to_id.keys())
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX")) self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens. The special tokens depend on calling set_lang.
An MBART sequence has the following format, where ``X`` represents the sequence:
- ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
- ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
BOS is never used.
Pairs of sequences are not the expected use case, but they will be handled without a separator.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
def get_special_tokens_mask( def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]: ) -> List[int]:
...@@ -156,6 +130,32 @@ class MBartTokenizer(XLMRobertaTokenizer): ...@@ -156,6 +130,32 @@ class MBartTokenizer(XLMRobertaTokenizer):
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens. The special tokens depend on calling set_lang.
An MBART sequence has the following format, where ``X`` represents the sequence:
- ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
- ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
BOS is never used.
Pairs of sequences are not the expected use case, but they will be handled without a separator.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
@add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
def prepare_seq2seq_batch( def prepare_seq2seq_batch(
self, self,
...@@ -251,7 +251,8 @@ class MBartTokenizer(XLMRobertaTokenizer): ...@@ -251,7 +251,8 @@ class MBartTokenizer(XLMRobertaTokenizer):
if max_target_length is None: if max_target_length is None:
max_target_length = max_length max_target_length = max_length
self.set_tgt_lang_special_tokens(tgt_lang) self.set_tgt_lang_special_tokens(tgt_lang)
-        decoder_inputs: BatchEncoding = self(
+        labels = self(
             tgt_texts,
             add_special_tokens=True,
             return_tensors=return_tensors,
@@ -259,10 +260,8 @@ class MBartTokenizer(XLMRobertaTokenizer):
             max_length=max_target_length,
             truncation=True,
             **kwargs,
-        )
-        for k, v in decoder_inputs.items():
-            model_inputs[f"decoder_{k}"] = v
+        )["input_ids"]
+        model_inputs["labels"] = labels
self.set_src_lang_special_tokens(src_lang) # sets to src_lang self.set_src_lang_special_tokens(src_lang) # sets to src_lang
return model_inputs return model_inputs
...@@ -275,5 +274,5 @@ class MBartTokenizer(XLMRobertaTokenizer): ...@@ -275,5 +274,5 @@ class MBartTokenizer(XLMRobertaTokenizer):
     def set_tgt_lang_special_tokens(self, lang: str) -> None:
         """Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos]."""
         self.cur_lang_code = self.lang_code_to_id[lang]
-        self.prefix_tokens = [self.cur_lang_code]
-        self.suffix_tokens = [self.eos_token_id]
+        self.prefix_tokens = []
+        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
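Because the target language code now sits after `eos` in `labels`, `decoder_input_ids` are produced later with `shift_tokens_right`, which rotates the last non-pad token back to the front. An illustrative sketch, assuming MBART's `pad`=1 and `eos`=2 (the other token ids are made up; 250020 is `ro_RO`, see the mbart tests below):

```python
import torch

from transformers.modeling_bart import shift_tokens_right

pad, eos, ro_RO = 1, 2, 250020
labels = torch.tensor([[47711, 7404, 999, eos, ro_RO]])  # X [eos, tgt_lang_code]
decoder_input_ids = shift_tokens_right(labels, pad)
print(decoder_input_ids)  # tensor([[250020, 47711, 7404, 999, 2]]), i.e. [tgt_lang_code] X [eos]
```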
...@@ -114,6 +114,7 @@ class PegasusTokenizer(ReformerTokenizer): ...@@ -114,6 +114,7 @@ class PegasusTokenizer(ReformerTokenizer):
return_tensors: str = "pt", return_tensors: str = "pt",
truncation=True, truncation=True,
padding="longest", padding="longest",
**unused,
) -> BatchEncoding: ) -> BatchEncoding:
""" """
Prepare model inputs for summarization or translation. Prepare model inputs for summarization or translation.
...@@ -133,7 +134,9 @@ class PegasusTokenizer(ReformerTokenizer): ...@@ -133,7 +134,9 @@ class PegasusTokenizer(ReformerTokenizer):
return model_inputs return model_inputs
if max_target_length is not None: if max_target_length is not None:
tokenizer_kwargs["max_length"] = max_target_length tokenizer_kwargs["max_length"] = max_target_length
-        decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
-        for k, v in decoder_inputs.items():
-            model_inputs[f"decoder_{k}"] = v
+        # TODO(@sshleifer): maybe tgt_texts = [self.pad_token + t for t in tgt_texts]  # add decoder_start_token_id
+        labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
+        model_inputs["labels"] = labels
+        # for k, v in decoder_inputs.items():
+        #     model_inputs[f"decoder_{k}"] = v
return model_inputs return model_inputs
...@@ -346,7 +346,7 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -346,7 +346,7 @@ class T5Tokenizer(PreTrainedTokenizer):
if max_length is None: if max_length is None:
max_length = self.max_len max_length = self.max_len
self.prefix_tokens = [] self.prefix_tokens = []
-        model_inputs: BatchEncoding = self(
+        model_inputs = self(
src_texts, src_texts,
add_special_tokens=True, add_special_tokens=True,
return_tensors=return_tensors, return_tensors=return_tensors,
...@@ -362,7 +362,7 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -362,7 +362,7 @@ class T5Tokenizer(PreTrainedTokenizer):
max_target_length = max_length max_target_length = max_length
# set prefix_tokens for target text # set prefix_tokens for target text
self.prefix_tokens = [self.pad_token_id] self.prefix_tokens = [self.pad_token_id]
-        decoder_inputs: BatchEncoding = self(
+        labels_and_decoder_mask = self(
tgt_texts, tgt_texts,
add_special_tokens=True, add_special_tokens=True,
return_tensors=return_tensors, return_tensors=return_tensors,
...@@ -371,8 +371,7 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -371,8 +371,7 @@ class T5Tokenizer(PreTrainedTokenizer):
truncation=truncation, truncation=truncation,
**kwargs, **kwargs,
) )
-        for k, v in decoder_inputs.items():
-            model_inputs[f"decoder_{k}"] = v
+        model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
+        model_inputs["decoder_attention_mask"] = labels_and_decoder_mask["attention_mask"]
self.prefix_tokens = [] self.prefix_tokens = []
return model_inputs return model_inputs
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
import timeout_decorator # noqa import timeout_decorator # noqa
-from transformers import BatchEncoding, is_torch_available
+from transformers import is_torch_available
from transformers.file_utils import cached_property from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device from transformers.testing_utils import require_torch, slow, torch_device
...@@ -496,7 +496,7 @@ class BartModelIntegrationTests(unittest.TestCase): ...@@ -496,7 +496,7 @@ class BartModelIntegrationTests(unittest.TestCase):
def test_xsum_summarization_same_as_fairseq(self): def test_xsum_summarization_same_as_fairseq(self):
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum").to(torch_device) model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum").to(torch_device)
self.assertFalse(model.config.is_valid_mbart()) self.assertFalse(model.config.is_valid_mbart())
-        tok = BartTokenizer.from_pretrained("facebook/bart-large")
+        tok = self.default_tokenizer
EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state." EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
dct = tok.batch_encode_plus( dct = tok.batch_encode_plus(
...@@ -585,84 +585,6 @@ class BartModelIntegrationTests(unittest.TestCase): ...@@ -585,84 +585,6 @@ class BartModelIntegrationTests(unittest.TestCase):
# TODO(SS): run fairseq again with num_beams=2, min_len=20. # TODO(SS): run fairseq again with num_beams=2, min_len=20.
# TODO(SS): add test case that hits max_length # TODO(SS): add test case that hits max_length
def test_prepare_seq2seq_batch(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 10), batch.input_ids.shape)
self.assertEqual((2, 10), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(expected_src_tokens, result)
# Test that special tokens are reset
def test_empty_target_text(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
# check if input_ids are returned and no decoder_input_ids
self.assertIn("input_ids", batch)
self.assertIn("attention_mask", batch)
self.assertNotIn("decoder_input_ids", batch)
self.assertNotIn("decoder_attention_mask", batch)
def test_max_target_length(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
def test_outputs_not_longer_than_maxlen(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(
["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 1024))
def test_special_tokens(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
src_text = ["A long paragraph for summrization."]
tgt_text = [
"Summary of the text.",
]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
input_ids = batch["input_ids"]
decoder_input_ids = batch["decoder_input_ids"]
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((decoder_input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
self.assertTrue((decoder_input_ids[:, -1] == tokenizer.eos_token_id).all().item())
@require_torch @require_torch
class TestSinusoidalPositionalEmbeddings(unittest.TestCase): class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
......
import json
import os
import unittest
from transformers import BartTokenizer, BartTokenizerFast, BatchEncoding
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch
from transformers.tokenization_roberta import VOCAB_FILES_NAMES
from .test_tokenization_common import TokenizerTesterMixin
class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = BartTokenizer
def setUp(self):
super().setUp()
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return BartTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self, tokenizer):
return "lower newer", "lower newer"
@cached_property
def default_tokenizer(self):
return BartTokenizer.from_pretrained("facebook/bart-large")
@cached_property
def default_tokenizer_fast(self):
return BartTokenizerFast.from_pretrained("facebook/bart-large")
@require_torch
def test_prepare_seq2seq_batch(self):
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 10), batch.input_ids.shape)
self.assertEqual((2, 10), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(expected_src_tokens, result)
# Test that special tokens are reset
# Test Prepare Seq
@require_torch
def test_seq2seq_batch_empty_target_text(self):
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
# check if input_ids are returned and no labels
self.assertIn("input_ids", batch)
self.assertIn("attention_mask", batch)
self.assertNotIn("labels", batch)
self.assertNotIn("decoder_attention_mask", batch)
@require_torch
def test_seq2seq_batch_max_target_length(self):
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["labels"].shape[1])
# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["labels"].shape[1])
@require_torch
def test_seq2seq_batch_not_longer_than_maxlen(self):
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 1024))
@require_torch
def test_special_tokens(self):
src_text = ["A long paragraph for summrization."]
tgt_text = [
"Summary of the text.",
]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
input_ids = batch["input_ids"]
labels = batch["labels"]
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
self.assertTrue((labels[:, -1] == tokenizer.eos_token_id).all().item())
...@@ -1555,14 +1555,19 @@ class TokenizerTesterMixin: ...@@ -1555,14 +1555,19 @@ class TokenizerTesterMixin:
"vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.", "vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
] ]
         batch = tokenizer.prepare_seq2seq_batch(
-            src_texts=src_text, tgt_texts=tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
+            src_texts=src_text,
+            tgt_texts=tgt_text,
+            max_length=3,
+            max_target_length=10,
+            return_tensors="pt",
+            src_lang="en_XX",  # this should be ignored (for all but mbart) but not cause an error
         )
         self.assertEqual(batch.input_ids.shape[1], 3)
-        self.assertEqual(batch.decoder_input_ids.shape[1], 10)
+        self.assertEqual(batch.labels.shape[1], 10)
         # max_target_length will default to max_length if not specified
         batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3)
         self.assertEqual(batch.input_ids.shape[1], 3)
-        self.assertEqual(batch.decoder_input_ids.shape[1], 3)
+        self.assertEqual(batch.labels.shape[1], 3)
batch_encoder_only = tokenizer.prepare_seq2seq_batch( batch_encoder_only = tokenizer.prepare_seq2seq_batch(
src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt" src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt"
......
import tempfile import tempfile
import unittest import unittest
-from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer
+from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer, is_torch_available
from transformers.testing_utils import require_torch from transformers.testing_utils import require_torch
from .test_tokenization_common import TokenizerTesterMixin from .test_tokenization_common import TokenizerTesterMixin
from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE
if is_torch_available():
from transformers.modeling_bart import shift_tokens_right
EN_CODE = 250004 EN_CODE = 250004
RO_CODE = 250020 RO_CODE = 250020
...@@ -123,35 +126,6 @@ class MBartEnroIntegrationTest(unittest.TestCase): ...@@ -123,35 +126,6 @@ class MBartEnroIntegrationTest(unittest.TestCase):
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004) self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020) self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
def test_enro_tokenizer_prepare_seq2seq_batch(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text,
tgt_texts=self.tgt_text,
max_length=len(self.expected_src_tokens),
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
self.assertEqual((2, 14), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
def test_max_target_length(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
def test_enro_tokenizer_batch_encode_plus(self):
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
self.assertListEqual(self.expected_src_tokens, ids)
...@@ -169,7 +143,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer.prepare_seq2seq_batch(
- src_text, return_tensors=None, max_length=desired_max_length
+ src_text,
+ return_tensors=None,
+ max_length=desired_max_length,
).input_ids[0]
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], EN_CODE)
...@@ -184,3 +160,53 @@ class MBartEnroIntegrationTest(unittest.TestCase):
self.tokenizer.save_pretrained(tmpdirname)
new_tok = MBartTokenizer.from_pretrained(tmpdirname)
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
# prepare_seq2seq_batch tests below
@require_torch
def test_batch_fairseq_parity(self):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
for k in batch:
batch[k] = batch[k].tolist()
# batch = {k: v.tolist() for k,v in batch.items()}
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
# batch.decoder_inputs_ids[0][0] ==
assert batch.input_ids[1][-2:] == [2, EN_CODE]
assert batch.decoder_input_ids[1][0] == RO_CODE
assert batch.decoder_input_ids[1][-1] == 2
assert batch.labels[1][-2:] == [2, RO_CODE]
@require_torch
def test_enro_tokenizer_prepare_seq2seq_batch(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text,
tgt_texts=self.tgt_text,
max_length=len(self.expected_src_tokens),
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
self.assertEqual((2, 14), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
def test_seq2seq_max_target_length(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
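The parity test above builds `decoder_input_ids` by shifting `labels` with `shift_tokens_right`. Roughly, that helper rotates each sequence so the last non-padding token (EOS, or the trailing language code in mBART's case) moves to position 0, which is why `decoder_input_ids[1][0] == RO_CODE` while the shifted sequence still ends in EOS (`2`). A sketch of that behaviour, mirroring the helper in `transformers.modeling_bart` at the time of this change (treat it as illustrative rather than canonical):

```python
import torch


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Shift input ids one token to the right, wrapping the last non-pad token to position 0."""
    prev_output_tokens = input_ids.clone()
    # Index of the last non-padding token in each row (usually EOS or a language code).
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens
```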
...@@ -63,7 +63,6 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5)
assert batch.input_ids.shape == (2, 1024)
assert batch.attention_mask.shape == (2, 1024)
- assert "decoder_input_ids" in batch  # because tgt_texts was specified
+ assert "labels" in batch  # because tgt_texts was specified
- assert batch.decoder_input_ids.shape == (2, 5)
+ assert batch.labels.shape == (2, 5)
- assert batch.decoder_attention_mask.shape == (2, 5)
- assert len(batch) == 4  # no extra keys
+ assert len(batch) == 3  # input_ids, attention_mask, labels. Other things make by BartModel
...@@ -66,7 +66,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
- return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
+ return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
...@@ -78,7 +78,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
return input_text, output_text
def test_full_tokenizer(self):
- tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
+ tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer"
bpe_tokens = ["l", "o", "w", "er", "\u0120", "n", "e", "w", "er"]
tokens = tokenizer.tokenize(text)  # , add_prefix_space=True)
...@@ -99,7 +99,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@slow
def test_sequence_builders(self):
- tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+ tokenizer = self.tokenizer_class.from_pretrained("roberta-base")
text = tokenizer.encode("sequence builders", add_special_tokens=False)
text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
...@@ -137,7 +137,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
first_char = tokenizer.convert_ids_to_tokens(encoded[1])[0]
self.assertNotEqual(first_char, space_encoding)
- # Testing spaces after special tokenss
+ # Testing spaces after special tokens
mask = "<mask>"
tokenizer.add_special_tokens(
{"mask_token": AddedToken(mask, lstrip=True, rstrip=False)}
...
...@@ -153,7 +153,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_max_target_length(self):
tokenizer = self.t5_base_tokenizer
- src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
+ src_text = ["A short paragraph for summrization.", "Another short paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
...@@ -161,14 +161,14 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK
)
- self.assertEqual(32, batch["decoder_input_ids"].shape[1])
+ self.assertEqual(32, batch["labels"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK
)
- self.assertEqual(32, batch["decoder_input_ids"].shape[1])
+ self.assertEqual(32, batch["labels"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
def test_outputs_not_longer_than_maxlen(self):
...@@ -190,7 +190,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
src_ids = list(batch.input_ids.numpy()[0])
- tgt_ids = list(batch.decoder_input_ids.numpy()[0])
+ tgt_ids = list(batch.labels.numpy()[0])
self.assertEqual(expected_src_tokens, src_ids)
self.assertEqual(expected_tgt_tokens, tgt_ids)
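With the tokenizer emitting `labels`, deriving `decoder_input_ids` becomes the training code's job and is model-specific: T5 shifts right starting from its decoder start (pad) token, while Bart-style models (Pegasus, Marian, mBART) use the wrap-around shift sketched earlier. A hedged sketch of that dispatch — the helper name is made up for illustration, and `_shift_right` is a private T5 method whose behaviour may differ across versions:

```python
import torch
from transformers import T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right


def make_decoder_input_ids(model, labels: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # T5 prepends its decoder start (pad) token and drops the last position;
    # Bart-like models rotate the last non-pad token to the front instead.
    if isinstance(model, T5ForConditionalGeneration):
        return model._shift_right(labels)
    return shift_tokens_right(labels, pad_token_id)
```

This mirrors the overall pattern of the commit: labels come from the tokenizer, decoder inputs are made later.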