Unverified Commit 3c7fbf35 authored by Sam Shleifer, committed by GitHub

MBART: support summarization tasks where max_src_len > max_tgt_len (#6003)

* MBART: support summarization tasks

* fix test

* Style

* add tokenizer test
parent 842eb456
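
The change touches the tokenizer, the example dataset, and the Lightning modules: `MBartTokenizer.prepare_translation_batch` gains a `max_target_length` argument (defaulting to `max_length`), `MBartDataset` forwards its target length instead of reusing the source length, and `SummarizationModule` now selects `MBartDataset` and the mBART decoder start token itself. A minimal sketch of the new tokenizer behaviour, assuming the `facebook/mbart-large-en-ro` checkpoint and illustrative length limits:

```python
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")

batch = tokenizer.prepare_translation_batch(
    ["A long source article that may need aggressive truncation ..."],
    tgt_texts=["A short summary."],
    max_length=1024,        # truncation limit for the source side (input_ids)
    max_target_length=56,   # truncation limit for the target side; falls back to max_length if omitted
)
print(batch.input_ids.shape, batch.decoder_input_ids.shape)
```

Previously the target side was always truncated to the same `max_length` as the source, which forced summarization runs to either waste target positions or shorten the source.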
@@ -180,6 +180,8 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
     --task summarization \
     --n_obs 100 \
     --device cuda \
+    --max_source_length 1024 \
+    --max_target_length 56 \
     --fp16 \
     --bs 32
 ```
...
@@ -105,7 +105,13 @@ class SummarizationModule(BaseTransformer):
         self.hparams.git_sha = get_git_info()["repo_sha"]
         self.num_workers = hparams.num_workers
         self.decoder_start_token_id = None
-        self.dataset_class = Seq2SeqDataset
+        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
+            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
+            self.model.config.decoder_start_token_id = self.decoder_start_token_id
+        if isinstance(self.tokenizer, MBartTokenizer):
+            self.dataset_class = MBartDataset
+        else:
+            self.dataset_class = Seq2SeqDataset

     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -331,11 +337,6 @@ class TranslationModule(SummarizationModule):
         super().__init__(hparams, **kwargs)
         self.dataset_kwargs["src_lang"] = hparams.src_lang
         self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
-        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
-            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
-            self.model.config.decoder_start_token_id = self.decoder_start_token_id
-        if isinstance(self.tokenizer, MBartTokenizer):
-            self.dataset_class = MBartDataset

     def calc_generative_metrics(self, preds, target) -> dict:
         return calculate_bleu_score(preds, target)
...
@@ -8,6 +8,7 @@ python finetune.py \
     --eval_batch_size=$BS \
     --output_dir=$OUTPUT_DIR \
     --max_source_length=512 \
+    --max_target_length=56 \
     --val_check_interval=0.1 --n_val=200 \
     --do_train --do_predict \
     $@
...
@@ -300,14 +300,17 @@ def test_mbart_dataset_truncation():
     tmp_dir = make_test_data_dir()
     max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
     max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
-    trunc = 4
+    max_src_len = 4
+    max_tgt_len = 8
+    assert max_len_target > max_src_len  # Truncated
+    assert max_len_source > max_src_len
     src_lang, tgt_lang = "ro_RO", "de_DE"  # NOT WHAT IT WAS TRAINED ON
     train_dataset = MBartDataset(
         tokenizer,
         data_dir=tmp_dir,
         type_path="train",
-        max_source_length=trunc,
-        max_target_length=1000,  # ignored
+        max_source_length=max_src_len,
+        max_target_length=max_tgt_len,  # ignored
         src_lang=src_lang,
         tgt_lang=tgt_lang,
     )
@@ -316,17 +319,15 @@ def test_mbart_dataset_truncation():
         assert isinstance(batch, dict)
         assert batch["attention_mask"].shape == batch["input_ids"].shape
         # show that articles were trimmed.
-        assert batch["input_ids"].shape[1] == trunc
+        assert batch["input_ids"].shape[1] == max_src_len
         # show that targets are the same len
-        assert batch["decoder_input_ids"].shape[1] == trunc
+        assert batch["decoder_input_ids"].shape[1] == max_tgt_len
         # check language codes in correct place
         assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
         assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
         assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
         assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
-        assert max_len_target > trunc  # Truncated
-        assert max_len_source > trunc
         break  # No need to test every batch
...
@@ -157,7 +157,8 @@ class MBartDataset(Seq2SeqDataset):
         super().__init__(*args, **kwargs)
         if self.max_source_length != self.max_target_length:
             warnings.warn(
-                f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides."
+                f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. "
+                f"Imbalanced sequence lengths may be undesired for translation tasks"
             )

     def __getitem__(self, index) -> Dict[str, str]:
@@ -178,6 +179,7 @@ class MBartDataset(Seq2SeqDataset):
             tgt_texts=[x["tgt_texts"] for x in batch],
             tgt_lang=self.tgt_lang,
             max_length=self.max_source_length,
+            max_target_length=self.max_target_length,
         )
         return batch_encoding.data
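
With the collate step above forwarding `max_target_length` to `prepare_translation_batch`, source and target truncation are now configurable independently end to end. A rough sketch of the wiring, mirroring the constructor arguments exercised in the truncation test; the `collate_fn` hook, the `from utils import` path, and the `data_dir` layout (`train.source`/`train.target` files) are assumptions about the surrounding examples/seq2seq code rather than part of this diff:

```python
# Assumes it is run from examples/seq2seq, where utils.py defines MBartDataset.
from torch.utils.data import DataLoader
from transformers import MBartTokenizer
from utils import MBartDataset

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
dataset = MBartDataset(
    tokenizer,
    data_dir="path/to/data",   # hypothetical directory with train.source / train.target
    type_path="train",
    max_source_length=1024,
    max_target_length=56,      # no longer ignored: forwarded to the tokenizer
    src_lang="en_XX",
    tgt_lang="ro_RO",
)
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
batch = next(iter(loader))     # input_ids truncated to <=1024 tokens, decoder_input_ids to <=56
```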
...
@@ -193,6 +193,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
         tgt_texts: Optional[List[str]] = None,
         tgt_lang: str = "ro_RO",
         max_length: Optional[int] = None,
+        max_target_length: Optional[int] = None,
         padding: str = "longest",
         return_tensors: str = "pt",
         **kwargs,
@@ -224,13 +225,16 @@ class MBartTokenizer(XLMRobertaTokenizer):
         )
         if tgt_texts is None:
             return model_inputs
+        # Process tgt_texts
+        if max_target_length is None:
+            max_target_length = max_length
         self.set_tgt_lang_special_tokens(tgt_lang)
         decoder_inputs: BatchEncoding = self(
             tgt_texts,
             add_special_tokens=True,
             return_tensors=return_tensors,
             padding=padding,
-            max_length=max_length,
+            max_length=max_target_length,
             truncation=True,
             **kwargs,
         )
...
@@ -137,6 +137,18 @@ class MBartEnroIntegrationTest(unittest.TestCase):
         self.assertEqual(self.tokenizer.prefix_tokens, [])
         self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])

+    def test_max_target_length(self):
+        batch = self.tokenizer.prepare_translation_batch(
+            self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
+        )
+        self.assertEqual(batch.input_ids.shape[1], 3)
+        self.assertEqual(batch.decoder_input_ids.shape[1], 10)
+
+        # max_target_length will default to max_length if not specified
+        batch = self.tokenizer.prepare_translation_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
+        self.assertEqual(batch.input_ids.shape[1], 3)
+        self.assertEqual(batch.decoder_input_ids.shape[1], 3)
+
     def test_enro_tokenizer_batch_encode_plus(self):
         ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
         self.assertListEqual(self.expected_src_tokens, ids)
...