Unverified Commit 9336086a authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

prepare_seq2seq_batch makes labels/ decoder_input_ids made later. (#6654)

* broken test

* batch parity

* tests pass

* boom boom

* boom boom

* split out bart tokenizer tests

* fix tests

* boom boom

* Fixed dataset bug

* Fix marian

* Undo extra

* Get marian working

* Fix t5 tok tests

* Test passing

* Cleanup

* better assert msg

* require torch

* Fix mbart tests

* undo extra decoder_attn_mask change

* Fix import

* pegasus tokenizer can ignore src_lang kwargs

* unused kwarg test cov

* boom boom

* add todo for pegasus issue

* cover one word translation edge case

* Cleanup

* doc
parent cb276b41
......@@ -71,8 +71,8 @@ Summarization Tips:
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
**Update 2018-07-18**
Datasets: `Seq2SeqDataset` should be used for all tokenizers without a `prepare_seq2seq_batch` method. For those who do (like Marian, MBart), `TranslationDataset` should be used.**
A new dataset is needed to support multilingual tasks.
Datasets: `LegacySeq2SeqDataset` will be used for all tokenizers without a `prepare_seq2seq_batch` method. Otherwise, `Seq2SeqDataset` will be used.
Future work/help wanted: A new dataset to support multilingual tasks.
### Command Line Options
......@@ -106,7 +106,7 @@ The following command should work on a 16GB GPU:
--train_batch_size=1 \
--eval_batch_size=1 \
--output_dir=xsum_results \
--num_train_epochs 1 \
--num_train_epochs 6 \
--model_name_or_path facebook/bart-large
```
......
import argparse
import gc
import os
import warnings
from pathlib import Path
from typing import List
......@@ -11,6 +12,7 @@ from torch.nn import functional as F
from lightning_base import generic_train
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
try:
......@@ -22,6 +24,7 @@ try:
assert_all_frozen,
calculate_bleu,
freeze_params,
label_smoothed_nll_loss,
pickle_load,
use_task_specific_params,
)
......@@ -34,12 +37,15 @@ except ImportError:
assert_all_frozen,
calculate_bleu,
freeze_params,
label_smoothed_nll_loss,
pickle_load,
use_task_specific_params,
)
class BartSummarizationDistiller(SummarizationModule):
"""Supports Bart, Pegasus and other models that inherit from Bart."""
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
def __init__(self, hparams):
......@@ -160,22 +166,32 @@ class BartSummarizationDistiller(SummarizationModule):
def _step(self, batch):
# assert is_frozen(self.teacher)
pad_token_id = self.tokenizer.pad_token_id
input_ids, src_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
decoder_input_ids = y[:, :-1].contiguous()
labels = y[:, 1:].clone()
labels[y[:, 1:] == pad_token_id] = -100
input_ids, src_mask, tgt_ids = batch["input_ids"], batch["attention_mask"], batch["labels"]
decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
# noinspection PyCallingNonCallable
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
input_ids,
attention_mask=src_mask,
decoder_input_ids=decoder_input_ids,
labels=labels,
output_hidden_states=True,
output_attentions=False,
)
use_cache=False,
) # TODO(@sshleifer): return_dict=True cleanup
# Same cross entropy vs. label smoothing logic as finetune.py
assert lm_logits.shape[-1] == self.model.config.vocab_size
if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
student_lm_loss, _ = label_smoothed_nll_loss(
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
)
def zero_tensor():
return torch.tensor(0.0).type_as(sloss)
return torch.tensor(0.0).type_as(student_lm_loss)
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
if self.different_encoder:
......@@ -199,29 +215,26 @@ class BartSummarizationDistiller(SummarizationModule):
attention_mask=src_mask,
encoder_outputs=teacher_enc_outputs,
decoder_input_ids=decoder_input_ids,
lm_labels=labels,
lm_labels=tgt_ids,
output_hidden_states=True,
)
dec_mask = decoder_input_ids.ne(pad_token_id)
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
if self.alpha_hid > 0:
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
blended_loss = (
self.alpha_ce * loss_ce
+ self.alpha_mlm * sloss
+ self.alpha_mlm * student_lm_loss
+ self.hparams.alpha_encoder_loss * loss_encoder
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
)
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches):
assert not isinstance(
hidden_states, torch.Tensor
), f"expected list or tuple for hidden_states, got tensor of shape {hidden_states.shape}"
assert not isinstance(
hidden_states_T, torch.Tensor
), f"expected list or tuple for hidden_states_T, got tensor of shape {hidden_states_T.shape}"
msg = "expected list or tuple for hidden_states, got tensor of shape: "
assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
mask = attention_mask.to(hidden_states[0])
valid_count = mask.sum() * hidden_states[0].size(-1)
hidden_losses = [
......@@ -233,7 +246,7 @@ class BartSummarizationDistiller(SummarizationModule):
def add_distill_args(parser):
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
parser.add_argument("--teacher", type=str)
parser.add_argument("--alpha_ce", default=0.8, type=float)
parser.add_argument("--alpha_mlm", default=0.2, type=float)
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
......@@ -245,8 +258,9 @@ def add_distill_args(parser):
class BartTranslationDistiller(BartSummarizationDistiller):
"""Supports Mbart, Marian, other models that inherit from Bart."""
mode = "translation"
loss_names = ["loss"]
metric_names = ["bleu"]
val_metric = "bleu"
......@@ -368,7 +382,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
attention_mask=source_mask,
encoder_outputs=teacher_enc_outputs,
decoder_input_ids=decoder_input_ids,
lm_labels=labels,
labels=labels,
output_hidden_states=True,
use_cache=False,
)
......@@ -402,6 +416,7 @@ def create_module(args):
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
# TODO(SS): DELETE?
exp_dir = ckpt_path.parent
if dest_dir is None:
dest_dir = exp_dir
......@@ -424,33 +439,40 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
trainer.test(model)
def get_layers_to_copy(n_to_get, tot):
all_layers = list(range(tot))
if tot == 12: # Alternating for special cases
layers_to_copy = { # maps num layers in student -> which teacher layers to copy
1: [0],
2: [0, 6],
3: [0, 6, 11],
4: [0, 4, 8, 11],
6: [0, 2, 4, 7, 9, 11],
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
12: all_layers,
}
return layers_to_copy[n_to_get]
elif tot == 16:
layers_to_copy = { # maps num layers in student -> which teacher layers to copy
1: [0],
2: [0, 8],
3: [0, 8, 15],
4: [0, 5, 10, 15],
6: [0, 3, 6, 9, 12, 15],
8: [0, 2, 4, 6, 8, 10, 12, 15],
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
16: all_layers,
}
return layers_to_copy[n_to_get]
else:
return all_layers[:n_to_get] # TODO: better version on theseus-bart branch
LAYERS_TO_COPY = {
# maps num layers in student -> which teacher layers to copy.
# 12: bart, 16: pegasus, 6: marian/Helsinki-NLP
12: {
1: [0],
2: [0, 6],
3: [0, 6, 11],
4: [0, 4, 8, 11],
6: [0, 2, 4, 7, 9, 11],
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
12: list(range(12)),
},
16: { # maps num layers in student -> which teacher layers to copy
1: [0],
2: [0, 8],
3: [0, 8, 15],
4: [0, 5, 10, 15],
6: [0, 3, 6, 9, 12, 15],
8: [0, 2, 4, 6, 8, 10, 12, 15],
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
16: list(range(16)),
},
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
}
def get_layers_to_copy(n_student, n_teacher):
try:
return LAYERS_TO_COPY[n_teacher][n_student]
except KeyError:
warnings.warn(
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
)
return list(range(n_student))
def distill_main(args):
......
......@@ -13,15 +13,16 @@ import torch
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
try:
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from .utils import (
ROUGE_KEYS,
LegacySeq2SeqDataset,
Seq2SeqDataset,
TranslationDataset,
assert_all_frozen,
calculate_bleu,
calculate_rouge,
......@@ -39,8 +40,8 @@ except ImportError:
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from utils import (
ROUGE_KEYS,
LegacySeq2SeqDataset,
Seq2SeqDataset,
TranslationDataset,
assert_all_frozen,
calculate_bleu,
calculate_rouge,
......@@ -102,14 +103,13 @@ class SummarizationModule(BaseTransformer):
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers
self.decoder_start_token_id = None
self.decoder_start_token_id = None # default to config
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
self.model.config.decoder_start_token_id = self.decoder_start_token_id
if isinstance(self.tokenizer, MBartTokenizer) or isinstance(self.tokenizer, MarianTokenizer):
self.dataset_class = TranslationDataset
else:
self.dataset_class = Seq2SeqDataset
self.dataset_class = (
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
)
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
......@@ -134,27 +134,25 @@ class SummarizationModule(BaseTransformer):
def _step(self, batch: dict) -> Tuple:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
tgt_ids = batch["labels"]
if isinstance(self.model, T5ForConditionalGeneration):
decoder_input_ids = self.model._shift_right(target_ids)
lm_labels = target_ids
decoder_input_ids = self.model._shift_right(tgt_ids)
else:
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
lm_labels = target_ids[:, 1:].clone() # why clone?
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
lm_logits = outputs[0]
if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
lm_logits = outputs[0]
assert lm_logits.shape[-1] == self.model.config.vocab_size
loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
loss, nll_loss = label_smoothed_nll_loss(
lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
)
return (loss,)
......@@ -167,7 +165,7 @@ class SummarizationModule(BaseTransformer):
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
# tokens per batch
logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["decoder_input_ids"].ne(self.pad).sum()
logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
return {"loss": loss_tensors[0], "log": logs}
def validation_step(self, batch, batch_idx) -> Dict:
......@@ -204,7 +202,7 @@ class SummarizationModule(BaseTransformer):
)
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
preds: List[str] = self.ids_to_clean_text(generated_ids)
target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
target: List[str] = self.ids_to_clean_text(batch["labels"])
loss_tensors = self._step(batch)
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
rouge: Dict = self.calc_generative_metrics(preds, target)
......
......@@ -132,4 +132,6 @@ def run_generate():
if __name__ == "__main__":
# Usage for MT:
# python run_eval.py MODEL_NAME $DATA_DIR/test.source $save_dir/test_translations.txt --reference_path $DATA_DIR/test.target --score_path $save_dir/test_bleu.json --task translation $@
run_generate()
......@@ -10,18 +10,18 @@ from unittest.mock import patch
import pytest
import pytorch_lightning as pl
import torch
from pytest import param
from torch.utils.data import DataLoader
import lightning_base
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.modeling_bart import shift_tokens_right
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu
from .distillation import distill_main, evaluate_checkpoint
from .finetune import SummarizationModule, main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import Seq2SeqDataset, TranslationDataset, label_smoothed_nll_loss, lmap, load_json
from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
logging.basicConfig(level=logging.DEBUG)
......@@ -452,18 +452,27 @@ def test_pack_dataset():
assert orig_paths == new_paths
@pytest.mark.parametrize(["tok_name"], [pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)])
def test_mbart_dataset_truncation(tok_name):
@pytest.mark.parametrize(
["tok_name"],
[
pytest.param(MBART_TINY),
pytest.param(MARIAN_TINY),
pytest.param(T5_TINY),
pytest.param(BART_TINY),
pytest.param("google/pegasus-xsum"),
],
)
def test_seq2seq_dataset_truncation(tok_name):
tokenizer = AutoTokenizer.from_pretrained(tok_name)
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
max_src_len = 4
max_tgt_len = 8
assert max_len_target > max_src_len # Truncated
assert max_len_source > max_src_len
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
train_dataset = TranslationDataset(
assert max_len_target > max_src_len # Will be truncated
assert max_len_source > max_src_len # Will be truncated
src_lang, tgt_lang = "ro_RO", "de_DE" # ignored for all but mbart, but never causes error.
train_dataset = Seq2SeqDataset(
tokenizer,
data_dir=tmp_dir,
type_path="train",
......@@ -479,10 +488,11 @@ def test_mbart_dataset_truncation(tok_name):
# show that articles were trimmed.
assert batch["input_ids"].shape[1] == max_src_len
# show that targets are the same len
assert batch["decoder_input_ids"].shape[1] == max_tgt_len
if tok_name == MARIAN_TINY:
assert batch["labels"].shape[1] == max_tgt_len
if tok_name != MBART_TINY:
continue
# check language codes in correct place
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
......@@ -491,14 +501,14 @@ def test_mbart_dataset_truncation(tok_name):
break # No need to test every batch
@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
def test_summarization_dataset_truncation(tok):
@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")])
def test_legacy_dataset_truncation(tok):
tokenizer = AutoTokenizer.from_pretrained(tok)
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc_target = 4
train_dataset = Seq2SeqDataset(
train_dataset = LegacySeq2SeqDataset(
tokenizer,
data_dir=tmp_dir,
type_path="train",
......@@ -512,6 +522,6 @@ def test_summarization_dataset_truncation(tok):
assert batch["input_ids"].shape[1] == max_len_source
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
# show that targets were truncated
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
assert batch["labels"].shape[1] == trunc_target # Truncated
assert max_len_target > trunc_target # Truncated
break # No need to test every batch
......@@ -3,7 +3,6 @@ import json
import linecache
import os
import pickle
import warnings
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List
......@@ -41,6 +40,7 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
"""Only used by LegacyDataset"""
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
return tokenizer(
[line],
......@@ -75,7 +75,7 @@ def trim_batch(
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
class Seq2SeqDataset(Dataset):
class AbstractSeq2SeqDataset(Dataset):
def __init__(
self,
tokenizer,
......@@ -102,11 +102,28 @@ class Seq2SeqDataset(Dataset):
self.pad_token_id = self.tokenizer.pad_token_id
self.src_lang = src_lang
self.tgt_lang = tgt_lang
self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)
def __len__(self):
return len(self.src_lens)
@staticmethod
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
def make_sortish_sampler(self, batch_size):
return SortishSampler(self.src_lens, batch_size)
def __getitem__(self, item):
raise NotImplementedError("You must implement this")
def collate_fn(self, batch):
raise NotImplementedError("You must implement this")
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
"""Call tokenizer on src and tgt_lines"""
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
......@@ -121,42 +138,27 @@ class Seq2SeqDataset(Dataset):
return {
"input_ids": source_ids,
"attention_mask": src_mask,
"decoder_input_ids": target_ids,
"labels": target_ids,
}
@staticmethod
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
target_ids = torch.stack([x["labels"] for x in batch])
pad_token_id = self.pad_token_id
y = trim_batch(target_ids, pad_token_id)
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
batch = {
"input_ids": source_ids,
"attention_mask": source_mask,
"decoder_input_ids": y,
"labels": y,
}
return batch
def make_sortish_sampler(self, batch_size):
return SortishSampler(self.src_lens, batch_size)
class TranslationDataset(Seq2SeqDataset):
class Seq2SeqDataset(AbstractSeq2SeqDataset):
"""A dataset that calls prepare_seq2seq_batch."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.max_source_length != self.max_target_length:
warnings.warn(
f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. "
f"Imbalanced sequence lengths may be undesired for translation tasks"
)
def __getitem__(self, index) -> Dict[str, str]:
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
......@@ -169,6 +171,7 @@ class TranslationDataset(Seq2SeqDataset):
}
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
"""Call prepare_seq2seq_batch."""
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
src_lang=self.src_lang,
......@@ -176,6 +179,8 @@ class TranslationDataset(Seq2SeqDataset):
tgt_lang=self.tgt_lang,
max_length=self.max_source_length,
max_target_length=self.max_target_length,
return_tensors="pt",
add_prefix_space=self.add_prefix_space,
)
return batch_encoding.data
......@@ -276,7 +281,11 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
# Utilities for freezing parameters and checking whether they are frozen
def freeze_params(model: nn.Module):
"""Set requires_grad=False for each of model.parameters()"""
for par in model.parameters():
par.requires_grad = False
......
......@@ -151,6 +151,9 @@ def _prepare_bart_decoder_inputs(
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
else:
decoder_padding_mask = invert_mask(decoder_padding_mask)
if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1:
# never mask leading token, even if it is pad
decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1]
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
dtype=causal_mask_dtype, device=decoder_input_ids.device
)
......
......@@ -636,7 +636,7 @@ class T5PreTrainedModel(PreTrainedModel):
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
assert torch.all(shifted_input_ids >= 0).item(), "Verify that `labels` has only positive values and -100"
assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
return shifted_input_ids
......
......@@ -33,6 +33,7 @@ _all_bart_models = [
"facebook/bart-large-cnn",
"facebook/bart-large-xsum",
"yjernite/bart_eli5",
# This is not exhaustive: see https://huggingface.co/models?filter=bart
]
......@@ -117,6 +118,8 @@ class BartTokenizer(RobertaTokenizer):
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
"""
kwargs.pop("src_lang", None)
kwargs.pop("tgt_lang", None)
if max_length is None:
max_length = self.model_max_length
model_inputs: BatchEncoding = self(
......@@ -133,7 +136,7 @@ class BartTokenizer(RobertaTokenizer):
# Process tgt_texts
if max_target_length is None:
max_target_length = max_length
decoder_inputs: BatchEncoding = self(
labels = self(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
......@@ -141,10 +144,8 @@ class BartTokenizer(RobertaTokenizer):
max_length=max_target_length,
truncation=truncation,
**kwargs,
)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
)["input_ids"]
model_inputs["labels"] = labels
return model_inputs
......@@ -245,7 +246,7 @@ class BartTokenizerFast(RobertaTokenizerFast):
# Process tgt_texts
if max_target_length is None:
max_target_length = max_length
decoder_inputs: BatchEncoding = self(
labels = self(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
......@@ -253,8 +254,6 @@ class BartTokenizerFast(RobertaTokenizerFast):
max_length=max_target_length,
truncation=truncation,
**kwargs,
)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
)["input_ids"]
model_inputs["labels"] = labels
return model_inputs
......@@ -160,9 +160,7 @@ class MarianTokenizer(PreTrainedTokenizer):
tokenizer_kwargs["max_length"] = max_target_length
self.current_spm = self.spm_target
decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
self.current_spm = self.spm_source
return model_inputs
......
......@@ -98,32 +98,6 @@ class MBartTokenizer(XLMRobertaTokenizer):
self._additional_special_tokens = list(self.lang_code_to_id.keys())
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens. The special tokens depend on calling set_lang.
An MBART sequence has the following format, where ``X`` represents the sequence:
- ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
- ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
BOS is never used.
Pairs of sequences are not the expected use case, but they will be handled without a separator.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
......@@ -156,6 +130,32 @@ class MBartTokenizer(XLMRobertaTokenizer):
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens. The special tokens depend on calling set_lang.
An MBART sequence has the following format, where ``X`` represents the sequence:
- ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
- ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
BOS is never used.
Pairs of sequences are not the expected use case, but they will be handled without a separator.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
@add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
def prepare_seq2seq_batch(
self,
......@@ -251,7 +251,8 @@ class MBartTokenizer(XLMRobertaTokenizer):
if max_target_length is None:
max_target_length = max_length
self.set_tgt_lang_special_tokens(tgt_lang)
decoder_inputs: BatchEncoding = self(
labels = self(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
......@@ -259,10 +260,8 @@ class MBartTokenizer(XLMRobertaTokenizer):
max_length=max_target_length,
truncation=True,
**kwargs,
)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
)["input_ids"]
model_inputs["labels"] = labels
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
return model_inputs
......@@ -275,5 +274,5 @@ class MBartTokenizer(XLMRobertaTokenizer):
def set_tgt_lang_special_tokens(self, lang: str) -> None:
"""Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos]."""
self.cur_lang_code = self.lang_code_to_id[lang]
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
......@@ -114,6 +114,7 @@ class PegasusTokenizer(ReformerTokenizer):
return_tensors: str = "pt",
truncation=True,
padding="longest",
**unused,
) -> BatchEncoding:
"""
Prepare model inputs for summarization or translation.
......@@ -133,7 +134,9 @@ class PegasusTokenizer(ReformerTokenizer):
return model_inputs
if max_target_length is not None:
tokenizer_kwargs["max_length"] = max_target_length
decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
# TODO(@sshleifer): maybe tgt_texts = [self.pad_token + t for t in tgt_texts] # add decoder_start_token_id
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
model_inputs["labels"] = labels
# for k, v in decoder_inputs.items():
# model_inputs[f"decoder_{k}"] = v
return model_inputs
......@@ -346,7 +346,7 @@ class T5Tokenizer(PreTrainedTokenizer):
if max_length is None:
max_length = self.max_len
self.prefix_tokens = []
model_inputs: BatchEncoding = self(
model_inputs = self(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
......@@ -362,7 +362,7 @@ class T5Tokenizer(PreTrainedTokenizer):
max_target_length = max_length
# set prefix_tokens for target text
self.prefix_tokens = [self.pad_token_id]
decoder_inputs: BatchEncoding = self(
labels_and_decoder_mask = self(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
......@@ -371,8 +371,7 @@ class T5Tokenizer(PreTrainedTokenizer):
truncation=truncation,
**kwargs,
)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
model_inputs["decoder_attention_mask"] = labels_and_decoder_mask["attention_mask"]
self.prefix_tokens = []
return model_inputs
......@@ -18,7 +18,7 @@ import unittest
import timeout_decorator # noqa
from transformers import BatchEncoding, is_torch_available
from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
......@@ -496,7 +496,7 @@ class BartModelIntegrationTests(unittest.TestCase):
def test_xsum_summarization_same_as_fairseq(self):
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum").to(torch_device)
self.assertFalse(model.config.is_valid_mbart())
tok = BartTokenizer.from_pretrained("facebook/bart-large")
tok = self.default_tokenizer
EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
dct = tok.batch_encode_plus(
......@@ -585,84 +585,6 @@ class BartModelIntegrationTests(unittest.TestCase):
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
# TODO(SS): add test case that hits max_length
def test_prepare_seq2seq_batch(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 10), batch.input_ids.shape)
self.assertEqual((2, 10), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(expected_src_tokens, result)
# Test that special tokens are reset
def test_empty_target_text(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
# check if input_ids are returned and no decoder_input_ids
self.assertIn("input_ids", batch)
self.assertIn("attention_mask", batch)
self.assertNotIn("decoder_input_ids", batch)
self.assertNotIn("decoder_attention_mask", batch)
def test_max_target_length(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
def test_outputs_not_longer_than_maxlen(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(
["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 1024))
def test_special_tokens(self):
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
src_text = ["A long paragraph for summrization."]
tgt_text = [
"Summary of the text.",
]
for tokenizer in tokenizers:
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
input_ids = batch["input_ids"]
decoder_input_ids = batch["decoder_input_ids"]
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((decoder_input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
self.assertTrue((decoder_input_ids[:, -1] == tokenizer.eos_token_id).all().item())
@require_torch
class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
......
import json
import os
import unittest
from transformers import BartTokenizer, BartTokenizerFast, BatchEncoding
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch
from transformers.tokenization_roberta import VOCAB_FILES_NAMES
from .test_tokenization_common import TokenizerTesterMixin
class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = BartTokenizer
def setUp(self):
super().setUp()
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return BartTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self, tokenizer):
return "lower newer", "lower newer"
@cached_property
def default_tokenizer(self):
return BartTokenizer.from_pretrained("facebook/bart-large")
@cached_property
def default_tokenizer_fast(self):
return BartTokenizerFast.from_pretrained("facebook/bart-large")
@require_torch
def test_prepare_seq2seq_batch(self):
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 10), batch.input_ids.shape)
self.assertEqual((2, 10), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(expected_src_tokens, result)
# Test that special tokens are reset
# Test Prepare Seq
@require_torch
def test_seq2seq_batch_empty_target_text(self):
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
# check if input_ids are returned and no labels
self.assertIn("input_ids", batch)
self.assertIn("attention_mask", batch)
self.assertNotIn("labels", batch)
self.assertNotIn("decoder_attention_mask", batch)
@require_torch
def test_seq2seq_batch_max_target_length(self):
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["labels"].shape[1])
# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["labels"].shape[1])
@require_torch
def test_seq2seq_batch_not_longer_than_maxlen(self):
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(
["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 1024))
@require_torch
def test_special_tokens(self):
src_text = ["A long paragraph for summrization."]
tgt_text = [
"Summary of the text.",
]
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
input_ids = batch["input_ids"]
labels = batch["labels"]
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
self.assertTrue((labels[:, -1] == tokenizer.eos_token_id).all().item())
......@@ -1555,14 +1555,19 @@ class TokenizerTesterMixin:
"vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
]
batch = tokenizer.prepare_seq2seq_batch(
src_texts=src_text, tgt_texts=tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
src_texts=src_text,
tgt_texts=tgt_text,
max_length=3,
max_target_length=10,
return_tensors="pt",
src_lang="en_XX", # this should be ignored (for all but mbart) but not cause an error
)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
self.assertEqual(batch.labels.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
self.assertEqual(batch.labels.shape[1], 3)
batch_encoder_only = tokenizer.prepare_seq2seq_batch(
src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt"
......
import tempfile
import unittest
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer, is_torch_available
from transformers.testing_utils import require_torch
from .test_tokenization_common import TokenizerTesterMixin
from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE
if is_torch_available():
from transformers.modeling_bart import shift_tokens_right
EN_CODE = 250004
RO_CODE = 250020
......@@ -123,35 +126,6 @@ class MBartEnroIntegrationTest(unittest.TestCase):
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
def test_enro_tokenizer_prepare_seq2seq_batch(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text,
tgt_texts=self.tgt_text,
max_length=len(self.expected_src_tokens),
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
self.assertEqual((2, 14), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
def test_max_target_length(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
def test_enro_tokenizer_batch_encode_plus(self):
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
self.assertListEqual(self.expected_src_tokens, ids)
......@@ -169,7 +143,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer.prepare_seq2seq_batch(
src_text, return_tensors=None, max_length=desired_max_length
src_text,
return_tensors=None,
max_length=desired_max_length,
).input_ids[0]
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], EN_CODE)
......@@ -184,3 +160,53 @@ class MBartEnroIntegrationTest(unittest.TestCase):
self.tokenizer.save_pretrained(tmpdirname)
new_tok = MBartTokenizer.from_pretrained(tmpdirname)
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
# prepare_seq2seq_batch tests below
@require_torch
def test_batch_fairseq_parity(self):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
for k in batch:
batch[k] = batch[k].tolist()
# batch = {k: v.tolist() for k,v in batch.items()}
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
# batch.decoder_inputs_ids[0][0] ==
assert batch.input_ids[1][-2:] == [2, EN_CODE]
assert batch.decoder_input_ids[1][0] == RO_CODE
assert batch.decoder_input_ids[1][-1] == 2
assert batch.labels[1][-2:] == [2, RO_CODE]
@require_torch
def test_enro_tokenizer_prepare_seq2seq_batch(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text,
tgt_texts=self.tgt_text,
max_length=len(self.expected_src_tokens),
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
self.assertEqual((2, 14), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
def test_seq2seq_max_target_length(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
......@@ -63,7 +63,6 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5)
assert batch.input_ids.shape == (2, 1024)
assert batch.attention_mask.shape == (2, 1024)
assert "decoder_input_ids" in batch # because tgt_texts was specified
assert batch.decoder_input_ids.shape == (2, 5)
assert batch.decoder_attention_mask.shape == (2, 5)
assert len(batch) == 4 # no extra keys
assert "labels" in batch # because tgt_texts was specified
assert batch.labels.shape == (2, 5)
assert len(batch) == 3 # input_ids, attention_mask, labels. Other things make by BartModel
......@@ -66,7 +66,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
......@@ -78,7 +78,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer"
bpe_tokens = ["l", "o", "w", "er", "\u0120", "n", "e", "w", "er"]
tokens = tokenizer.tokenize(text) # , add_prefix_space=True)
......@@ -99,7 +99,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@slow
def test_sequence_builders(self):
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
tokenizer = self.tokenizer_class.from_pretrained("roberta-base")
text = tokenizer.encode("sequence builders", add_special_tokens=False)
text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
......@@ -137,7 +137,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
first_char = tokenizer.convert_ids_to_tokens(encoded[1])[0]
self.assertNotEqual(first_char, space_encoding)
# Testing spaces after special tokenss
# Testing spaces after special tokens
mask = "<mask>"
tokenizer.add_special_tokens(
{"mask_token": AddedToken(mask, lstrip=True, rstrip=False)}
......
......@@ -153,7 +153,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_max_target_length(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
src_text = ["A short paragraph for summrization.", "Another short paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
......@@ -161,14 +161,14 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["labels"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["labels"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
def test_outputs_not_longer_than_maxlen(self):
......@@ -190,7 +190,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
src_ids = list(batch.input_ids.numpy()[0])
tgt_ids = list(batch.decoder_input_ids.numpy()[0])
tgt_ids = list(batch.labels.numpy()[0])
self.assertEqual(expected_src_tokens, src_ids)
self.assertEqual(expected_tgt_tokens, tgt_ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment