"tests/models/herbert/test_tokenization_herbert.py" did not exist on "31c23bd5ee26425a67f92fc170789656379252a6"
Unverified commit 9a94bb8e, authored by Edoardo Federici and committed by GitHub

mBART support for run_summarization.py (#15125)

* Update run_summarization.py

* Fixed languages and added missing code

* fixed obj, docs, removed source_lang and target_lang

* make style, run_summarization.py reformatted
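As a rough illustration of the two new data arguments introduced in the diff below (`lang` and `forced_bos_token`), here is a minimal, self-contained sketch of parsing them with `HfArgumentParser`. The trimmed-down dataclass and the example values (`ro_RO`) are illustrative only and are not part of the commit.

```python
# Minimal sketch, not part of the commit: the dataclass mirrors only the two new
# fields from the diff below; the argument values are made up for illustration.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class SummarizationLangArgs:
    lang: str = field(default=None, metadata={"help": "Language id for summarization."})
    forced_bos_token: Optional[str] = field(
        default=None,
        metadata={"help": "Token to force as the first generated token, e.g. the mBART target language code."},
    )


parser = HfArgumentParser(SummarizationLangArgs)
(args,) = parser.parse_args_into_dataclasses(args=["--lang", "ro_RO", "--forced_bos_token", "ro_RO"])
print(args.lang, args.forced_bos_token)  # ro_RO ro_RO
```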
@@ -37,6 +37,10 @@ from transformers import (
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    MBart50Tokenizer,
    MBart50TokenizerFast,
    MBartTokenizer,
    MBartTokenizerFast,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
@@ -64,6 +68,9 @@ except (LookupError, OSError):
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)

# A list of all multilingual tokenizer which require lang attribute.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]


@dataclass
class ModelArguments:
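Since `AutoTokenizer` may return either the slow or the fast variant, the list above includes both. A minimal sketch of the membership check, assuming a downloadable mBART checkpoint such as `facebook/mbart-large-cc25` (the checkpoint name is illustrative):

```python
# Hedged sketch: shows why both the slow and fast MBart classes are listed.
from transformers import AutoTokenizer, MBart50Tokenizer, MBart50TokenizerFast, MBartTokenizer, MBartTokenizerFast

MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]

# Downloads the tokenizer files for an example checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
print(isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)))  # True (fast variant by default)
```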
@@ -114,6 +121,8 @@ class DataTrainingArguments:
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    lang: str = field(default=None, metadata={"help": "Language id for summarization."})

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
@@ -217,12 +226,24 @@ class DataTrainingArguments:
        },
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
        default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )

    forced_bos_token: Optional[str] = field(
        default=None,
        metadata={
            "help": "The token to force as the first generated token after the decoder_start_token_id."
            "Useful for multilingual models like mBART where the first generated token"
            "needs to be the target language token (Usually it is the target language token)"
        },
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        elif self.lang is None:
            raise ValueError("Need to specify the language.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
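The new `__post_init__` branch makes `--lang` mandatory whenever data is supplied. A standalone sketch of the validation order (a simplified, hypothetical stand-in, not the script's actual class):

```python
# Simplified, hypothetical stand-in for the validation above.
def validate(dataset_name=None, train_file=None, validation_file=None, lang=None):
    if dataset_name is None and train_file is None and validation_file is None:
        raise ValueError("Need either a dataset name or a training/validation file.")
    elif lang is None:
        raise ValueError("Need to specify the language.")


validate(dataset_name="xsum", lang="en_XX")  # passes
# validate(dataset_name="xsum")              # would raise: "Need to specify the language."
```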
@@ -370,6 +391,12 @@ def main():
    model.resize_token_embeddings(len(tokenizer))

    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
        if isinstance(tokenizer, MBartTokenizer):
            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
        else:
            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
@@ -406,6 +433,21 @@ def main():
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
        assert (
            data_args.lang is not None
        ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"

        tokenizer.src_lang = data_args.lang
        tokenizer.tgt_lang = data_args.lang

        # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
        # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
        forced_bos_token_id = (
            tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
        )
        model.config.forced_bos_token_id = forced_bos_token_id

    # Get the column names for input/target.
    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
    if data_args.text_column is None:
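To make the multilingual setup concrete, a minimal sketch assuming the `facebook/mbart-large-50` checkpoint; the `ro_RO` values stand in for `--lang` and `--forced_bos_token`:

```python
from transformers import MBart50TokenizerFast

lang = "ro_RO"              # would come from --lang
forced_bos_token = "ro_RO"  # would come from --forced_bos_token

tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
tokenizer.src_lang = lang
tokenizer.tgt_lang = lang

forced_bos_token_id = (
    tokenizer.lang_code_to_id[forced_bos_token] if forced_bos_token is not None else None
)
print(forced_bos_token_id)  # generate() will force this id as the first decoded token
```

Writing this id into `model.config.forced_bos_token_id` is what makes `generate()` start every summary with the target language token.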
@@ -436,14 +478,16 @@ def main():
        )

    def preprocess_function(examples):
        # remove pairs where at least one record is None
        inputs, targets = [], []
        for i in range(len(examples[text_column])):
            if examples[text_column][i] is not None and examples[summary_column][i] is not None:
                inputs.append(examples[text_column][i])
                targets.append(examples[summary_column][i])

        inputs = examples[text_column]
        targets = examples[summary_column]

        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
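For reference, the tokenization call above applied to toy data; the sample texts, checkpoint, and `max_length` value are made up, while the script takes them from `--source_prefix` and `--max_source_length`:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50")  # illustrative checkpoint
prefix = ""  # source_prefix now defaults to ""

examples = {"text": ["A long article ...", "Another long article ..."]}
inputs = [prefix + inp for inp in examples["text"]]
model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True)
print(len(model_inputs["input_ids"]))  # 2
```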
@@ -637,6 +681,9 @@ def main():
        else:
            kwargs["dataset"] = data_args.dataset_name

    if data_args.lang is not None:
        kwargs["language"] = data_args.lang

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else: