Unverified Commit 9a94bb8e authored by Edoardo Federici, committed by GitHub

mBART support for run_summarization.py (#15125)

* Update run_summarization.py

* Fixed languages and added missing code

* fixed obj, docs, removed source_lang and target_lang

* make style, run_summarization.py reformatted
parent 97f3beed
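The diff below wires two new options, --lang and --forced_bos_token, into the summarization example. As a rough sketch (not part of the commit) of how the updated script might be invoked with an mBART checkpoint, where the checkpoint name, dataset, and output path are placeholders:

# Hypothetical invocation of the updated example script; adjust the checkpoint,
# dataset, and paths to your own setup.
import subprocess

subprocess.run(
    [
        "python", "examples/pytorch/summarization/run_summarization.py",
        "--model_name_or_path", "facebook/mbart-large-cc25",  # example mBART checkpoint
        "--dataset_name", "xsum",                             # example dataset
        "--lang", "en_XX",                                     # new: language id for the tokenizer
        "--forced_bos_token", "en_XX",                         # new: force the target language token first
        "--do_train", "--do_eval",
        "--output_dir", "/tmp/mbart-summarization",
        "--predict_with_generate",
    ],
    check=True,
)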
@@ -37,6 +37,10 @@ from transformers import (
     AutoTokenizer,
     DataCollatorForSeq2Seq,
     HfArgumentParser,
+    MBart50Tokenizer,
+    MBart50TokenizerFast,
+    MBartTokenizer,
+    MBartTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     set_seed,
@@ -64,6 +68,9 @@ except (LookupError, OSError):
     with FileLock(".lock") as lock:
         nltk.download("punkt", quiet=True)

+# A list of all multilingual tokenizers which require the lang attribute.
+MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
+

 @dataclass
 class ModelArguments:
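For context, a minimal sketch of what this list is for: the mBART tokenizer classes wrap the text with language-code tokens, so a language id must be set before encoding. The checkpoint name below is only an example and loading it requires network access:

from transformers import AutoTokenizer, MBart50Tokenizer, MBart50TokenizerFast, MBartTokenizer, MBartTokenizerFast

MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]

# "facebook/mbart-large-cc25" is an illustrative checkpoint, not mandated by the PR
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
    # multilingual tokenizers need to know which language code to emit
    tokenizer.src_lang = "en_XX"
    tokenizer.tgt_lang = "en_XX"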
@@ -114,6 +121,8 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """

+    lang: str = field(default=None, metadata={"help": "Language id for summarization."})
+
     dataset_name: Optional[str] = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
@@ -217,12 +226,24 @@
         },
     )
     source_prefix: Optional[str] = field(
-        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
+        default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
     )
+    forced_bos_token: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The token to force as the first generated token after the decoder_start_token_id. "
+            "Useful for multilingual models like mBART, where the first generated token needs to be "
+            "the target language token."
+        },
+    )

     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
             raise ValueError("Need either a dataset name or a training/validation file.")
+        elif self.lang is None:
+            raise ValueError("Need to specify the language.")
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
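To show how these dataclass fields surface as command-line flags, here is a trimmed-down stand-in for DataTrainingArguments (a hypothetical MiniDataArgs, not the real class) parsed with HfArgumentParser:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class MiniDataArgs:
    # reduced stand-in containing only the two new fields
    lang: str = field(default=None, metadata={"help": "Language id for summarization."})
    forced_bos_token: Optional[str] = field(
        default=None, metadata={"help": "Token forced as the first generated token."}
    )


(args,) = HfArgumentParser(MiniDataArgs).parse_args_into_dataclasses(
    args=["--lang", "ro_RO", "--forced_bos_token", "ro_RO"]
)
print(args.lang, args.forced_bos_token)  # ro_RO ro_RO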
@@ -370,6 +391,12 @@ def main():
     model.resize_token_embeddings(len(tokenizer))

+    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
+        if isinstance(tokenizer, MBartTokenizer):
+            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
+        else:
+            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)
+
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
@@ -406,6 +433,21 @@ def main():
         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
         return

+    if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
+        assert (
+            data_args.lang is not None
+        ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
+
+        tokenizer.src_lang = data_args.lang
+        tokenizer.tgt_lang = data_args.lang
+
+        # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
+        # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
+        forced_bos_token_id = (
+            tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
+        )
+        model.config.forced_bos_token_id = forced_bos_token_id
+
     # Get the column names for input/target.
     dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
     if data_args.text_column is None:
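To show what forced_bos_token_id changes at generation time, here is a rough, self-contained example with an mBART-50 checkpoint; the checkpoint name and the Romanian input sentence are placeholders chosen for illustration:

from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="ro_RO")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

batch = tokenizer("Un articol lung care trebuie rezumat.", return_tensors="pt")
summary_ids = model.generate(
    **batch,
    # force the Romanian language code as the first generated token,
    # mirroring what the script does via model.config.forced_bos_token_id
    forced_bos_token_id=tokenizer.lang_code_to_id["ro_RO"],
    max_length=60,
)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))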
@@ -436,14 +478,16 @@ def main():
         )

     def preprocess_function(examples):
         # remove pairs where at least one record is None
         inputs, targets = [], []
         for i in range(len(examples[text_column])):
             if examples[text_column][i] is not None and examples[summary_column][i] is not None:
                 inputs.append(examples[text_column][i])
                 targets.append(examples[summary_column][i])
-        inputs = examples[text_column]
-        targets = examples[summary_column]
         inputs = [prefix + inp for inp in inputs]
         model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
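As a quick, self-contained illustration of the filtering in preprocess_function (toy column names and records, not taken from the PR):

examples = {
    "document": ["first article", None, "third article"],
    "summary": ["first summary", "second summary", None],
}
text_column, summary_column = "document", "summary"

inputs, targets = [], []
for i in range(len(examples[text_column])):
    # keep only records where both the document and the summary are present
    if examples[text_column][i] is not None and examples[summary_column][i] is not None:
        inputs.append(examples[text_column][i])
        targets.append(examples[summary_column][i])

print(inputs)   # ['first article']
print(targets)  # ['first summary']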
@@ -637,6 +681,9 @@ def main():
     else:
         kwargs["dataset"] = data_args.dataset_name

+    if data_args.lang is not None:
+        kwargs["language"] = data_args.lang
+
     if training_args.push_to_hub:
         trainer.push_to_hub(**kwargs)
     else: