Unverified Commit 6f840990 authored by Théo Matussière's avatar Théo Matussière Committed by GitHub
Browse files

split seq2seq script into summarization & translation (#10611)



* split seq2seq script, update docs

* needless diff

* fix readme

* remove test diff

* s/summarization/translation
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* cr

* fix arguments & better mbart/t5 refs

* copyright
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* reword readme
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* s/summarization/translation

* short script names

* fix tests

* fix isort, include mbart doc

* delete old script, update tests

* automate source prefix

* automate source prefix for translation

* s/translation/trans
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>

* fix script name (short version)

* typos
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>

* exact parameter
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>

* remove superfluous source_prefix calls in docs

* rename scripts & warn for source prefix

* black

* flake8
Co-authored-by: default avatartheo <theo@matussie.re>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>
parent 505494a8
...@@ -168,13 +168,13 @@ Here is an example of how this can be used on a filesystem that is shared betwee ...@@ -168,13 +168,13 @@ Here is an example of how this can be used on a filesystem that is shared betwee
On the instance with the normal network run your program which will download and cache models (and optionally datasets if you use 🤗 Datasets). For example: On the instance with the normal network run your program which will download and cache models (and optionally datasets if you use 🤗 Datasets). For example:
``` ```
python examples/seq2seq/run_seq2seq.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ... python examples/seq2seq/run_translation.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ...
``` ```
and then with the same filesystem you can now run the same program on a firewalled instance: and then with the same filesystem you can now run the same program on a firewalled instance:
``` ```
HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \ HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \
python examples/seq2seq/run_seq2seq.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ... python examples/seq2seq/run_translation.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ...
``` ```
and it should succeed without any hanging waiting to timeout. and it should succeed without any hanging waiting to timeout.
......
...@@ -279,16 +279,16 @@ To deploy this feature: ...@@ -279,16 +279,16 @@ To deploy this feature:
and make sure you have added the distributed launcher ``-m torch.distributed.launch and make sure you have added the distributed launcher ``-m torch.distributed.launch
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already. --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs: For example here is how you could use it for ``run_translation.py`` with 2 GPUs:
.. code-block:: bash .. code-block:: bash
python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_seq2seq.py \ python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_translation.py \
--model_name_or_path t5-small --per_device_train_batch_size 1 \ --model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir \ --output_dir output_dir --overwrite_output_dir \
--do_train --max_train_samples 500 --num_train_epochs 1 \ --do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \ --dataset_name wmt16 --dataset_config "ro-en" \
--task translation_en_to_ro --source_prefix "translate English to Romanian: " \ --source_lang en --target_lang ro \
--fp16 --sharded_ddp simple --fp16 --sharded_ddp simple
Notes: Notes:
...@@ -304,16 +304,16 @@ Notes: ...@@ -304,16 +304,16 @@ Notes:
to the command line arguments, and make sure you have added the distributed launcher ``-m torch.distributed.launch to the command line arguments, and make sure you have added the distributed launcher ``-m torch.distributed.launch
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already. --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs: For example here is how you could use it for ``run_translation.py`` with 2 GPUs:
.. code-block:: bash .. code-block:: bash
python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_seq2seq.py \ python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_translation.py \
--model_name_or_path t5-small --per_device_train_batch_size 1 \ --model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir \ --output_dir output_dir --overwrite_output_dir \
--do_train --max_train_samples 500 --num_train_epochs 1 \ --do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \ --dataset_name wmt16 --dataset_config "ro-en" \
--task translation_en_to_ro --source_prefix "translate English to Romanian: " \ --source_lang en --target_lang ro \
--fp16 --sharded_ddp zero_dp_2 --fp16 --sharded_ddp zero_dp_2
:obj:`zero_dp_2` is an optimized version of the simple wrapper, while :obj:`zero_dp_3` fully shards model weights, :obj:`zero_dp_2` is an optimized version of the simple wrapper, while :obj:`zero_dp_3` fully shards model weights,
...@@ -333,7 +333,7 @@ Notes: ...@@ -333,7 +333,7 @@ Notes:
Known caveats: Known caveats:
- This feature is incompatible with :obj:`--predict_with_generate` in the `run_seq2seq.py` script. - This feature is incompatible with :obj:`--predict_with_generate` in the `run_translation.py` script.
- Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container - Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container
:obj:`FullyShardedDataParallelism` of fairscale. It should be used with the option :obj:`auto_wrap` if you are not :obj:`FullyShardedDataParallelism` of fairscale. It should be used with the option :obj:`auto_wrap` if you are not
doing this yourself: :obj:`--sharded_ddp "zero_dp_3 auto_wrap"`. doing this yourself: :obj:`--sharded_ddp "zero_dp_3 auto_wrap"`.
...@@ -402,17 +402,17 @@ In fact, you can continue using ``-m torch.distributed.launch`` with DeepSpeed a ...@@ -402,17 +402,17 @@ In fact, you can continue using ``-m torch.distributed.launch`` with DeepSpeed a
the ``deepspeed`` launcher. But since in the DeepSpeed documentation it'll be used everywhere, for consistency we will the ``deepspeed`` launcher. But since in the DeepSpeed documentation it'll be used everywhere, for consistency we will
use it here as well. use it here as well.
Here is an example of running ``run_seq2seq.py`` under DeepSpeed deploying all available GPUs: Here is an example of running ``run_translation.py`` under DeepSpeed deploying all available GPUs:
.. code-block:: bash .. code-block:: bash
deepspeed examples/seq2seq/run_seq2seq.py \ deepspeed examples/seq2seq/run_translation.py \
--deepspeed examples/tests/deepspeed/ds_config.json \ --deepspeed examples/tests/deepspeed/ds_config.json \
--model_name_or_path t5-small --per_device_train_batch_size 1 \ --model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir --fp16 \ --output_dir output_dir --overwrite_output_dir --fp16 \
--do_train --max_train_samples 500 --num_train_epochs 1 \ --do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \ --dataset_name wmt16 --dataset_config "ro-en" \
--task translation_en_to_ro --source_prefix "translate English to Romanian: " --source_lang en --target_lang ro
Note that in the DeepSpeed documentation you are likely to see ``--deepspeed --deepspeed_config ds_config.json`` - i.e. Note that in the DeepSpeed documentation you are likely to see ``--deepspeed --deepspeed_config ds_config.json`` - i.e.
...@@ -431,13 +431,13 @@ To deploy DeepSpeed with one GPU adjust the :class:`~transformers.Trainer` comma ...@@ -431,13 +431,13 @@ To deploy DeepSpeed with one GPU adjust the :class:`~transformers.Trainer` comma
.. code-block:: bash .. code-block:: bash
deepspeed --num_gpus=1 examples/seq2seq/run_seq2seq.py \ deepspeed --num_gpus=1 examples/seq2seq/run_translation.py \
--deepspeed examples/tests/deepspeed/ds_config.json \ --deepspeed examples/tests/deepspeed/ds_config.json \
--model_name_or_path t5-small --per_device_train_batch_size 1 \ --model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir --fp16 \ --output_dir output_dir --overwrite_output_dir --fp16 \
--do_train --max_train_samples 500 --num_train_epochs 1 \ --do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \ --dataset_name wmt16 --dataset_config "ro-en" \
--task translation_en_to_ro --source_prefix "translate English to Romanian: " --source_lang en --target_lang ro
This is almost the same as with multiple-GPUs, but here we tell DeepSpeed explicitly to use just one GPU. By default, This is almost the same as with multiple-GPUs, but here we tell DeepSpeed explicitly to use just one GPU. By default,
DeepSpeed deploys all GPUs it can see. If you have only 1 GPU to start with, then you don't need this argument. The DeepSpeed deploys all GPUs it can see. If you have only 1 GPU to start with, then you don't need this argument. The
...@@ -483,7 +483,7 @@ Notes: ...@@ -483,7 +483,7 @@ Notes:
.. code-block:: bash .. code-block:: bash
deepspeed --include localhost:1 examples/seq2seq/run_seq2seq.py ... deepspeed --include localhost:1 examples/seq2seq/run_translation.py ...
In this example, we tell DeepSpeed to use GPU 1 (second gpu). In this example, we tell DeepSpeed to use GPU 1 (second gpu).
...@@ -574,7 +574,7 @@ with: ...@@ -574,7 +574,7 @@ with:
.. code-block:: .. code-block::
!deepspeed examples/seq2seq/run_seq2seq.py ... !deepspeed examples/seq2seq/run_translation.py ...
or with bash magic, where you can write a multi-line code for the shell to run: or with bash magic, where you can write a multi-line code for the shell to run:
...@@ -583,7 +583,7 @@ or with bash magic, where you can write a multi-line code for the shell to run: ...@@ -583,7 +583,7 @@ or with bash magic, where you can write a multi-line code for the shell to run:
%%bash %%bash
cd /somewhere cd /somewhere
deepspeed examples/seq2seq/run_seq2seq.py ... deepspeed examples/seq2seq/run_translation.py ...
......
...@@ -742,8 +742,8 @@ Summarization ...@@ -742,8 +742,8 @@ Summarization
----------------------------------------------------------------------------------------------------------------------- -----------------------------------------------------------------------------------------------------------------------
Summarization is the task of summarizing a document or an article into a shorter text. If you would like to fine-tune a Summarization is the task of summarizing a document or an article into a shorter text. If you would like to fine-tune a
model on a summarization task, you may leverage the `run_seq2seq.py model on a summarization task, you may leverage the `run_summarization.py
<https://github.com/huggingface/transformers/tree/master/examples/seq2seq/run_seq2seq.py>`__ script. <https://github.com/huggingface/transformers/tree/master/examples/seq2seq/run_summarization.py>`__ script.
An example of a summarization dataset is the CNN / Daily Mail dataset, which consists of long news articles and was An example of a summarization dataset is the CNN / Daily Mail dataset, which consists of long news articles and was
created for the task of summarization. If you would like to fine-tune a model on a summarization task, various created for the task of summarization. If you would like to fine-tune a model on a summarization task, various
...@@ -822,8 +822,8 @@ Translation ...@@ -822,8 +822,8 @@ Translation
----------------------------------------------------------------------------------------------------------------------- -----------------------------------------------------------------------------------------------------------------------
Translation is the task of translating a text from one language to another. If you would like to fine-tune a model on a Translation is the task of translating a text from one language to another. If you would like to fine-tune a model on a
translation task, you may leverage the `run_seq2seq.py translation task, you may leverage the `run_translation.py
<https://github.com/huggingface/transformers/tree/master/examples/seq2seq/run_seq2seq.py>`__ script. <https://github.com/huggingface/transformers/tree/master/examples/seq2seq/run_translation.py>`__ script.
An example of a translation dataset is the WMT English to German dataset, which has sentences in English as the input An example of a translation dataset is the WMT English to German dataset, which has sentences in English as the input
data and the corresponding sentences in German as the target data. If you would like to fine-tune a model on a data and the corresponding sentences in German as the target data. If you would like to fine-tune a model on a
......
...@@ -30,7 +30,7 @@ For the old `finetune_trainer.py` and related utils, see [`examples/legacy/seq2s ...@@ -30,7 +30,7 @@ For the old `finetune_trainer.py` and related utils, see [`examples/legacy/seq2s
- `FSMTForConditionalGeneration` (translation only) - `FSMTForConditionalGeneration` (translation only)
- `T5ForConditionalGeneration` - `T5ForConditionalGeneration`
`run_seq2seq.py` is a lightweight example of how to download and preprocess a dataset from the [🤗 Datasets](https://github.com/huggingface/datasets) library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it. `run_summarization.py` and `run_translation.py` are lightweight examples of how to download and preprocess a dataset from the [🤗 Datasets](https://github.com/huggingface/datasets) library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.
For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files
and you also will find examples of these below. and you also will find examples of these below.
...@@ -39,11 +39,10 @@ and you also will find examples of these below. ...@@ -39,11 +39,10 @@ and you also will find examples of these below.
Here is an example on a summarization task: Here is an example on a summarization task:
```bash ```bash
python examples/seq2seq/run_seq2seq.py \ python examples/seq2seq/run_summarization.py \
--model_name_or_path t5-small \ --model_name_or_path t5-small \
--do_train \ --do_train \
--do_eval \ --do_eval \
--task summarization \
--dataset_name xsum \ --dataset_name xsum \
--output_dir /tmp/tst-summarization \ --output_dir /tmp/tst-summarization \
--per_device_train_batch_size=4 \ --per_device_train_batch_size=4 \
...@@ -60,11 +59,10 @@ And here is how you would use it on your own files, after adjusting the values f ...@@ -60,11 +59,10 @@ And here is how you would use it on your own files, after adjusting the values f
`--train_file`, `--validation_file`, `--text_column` and `--summary_column` to match your setup: `--train_file`, `--validation_file`, `--text_column` and `--summary_column` to match your setup:
```bash ```bash
python examples/seq2seq/run_seq2seq.py \ python examples/seq2seq/run_summarization.py \
--model_name_or_path t5-small \ --model_name_or_path t5-small \
--do_train \ --do_train \
--do_eval \ --do_eval \
--task summarization \
--train_file path_to_csv_or_jsonlines_file \ --train_file path_to_csv_or_jsonlines_file \
--validation_file path_to_csv_or_jsonlines_file \ --validation_file path_to_csv_or_jsonlines_file \
--output_dir /tmp/tst-summarization \ --output_dir /tmp/tst-summarization \
...@@ -140,14 +138,14 @@ And as with the CSV files, you can specify which values to select from the file, ...@@ -140,14 +138,14 @@ And as with the CSV files, you can specify which values to select from the file,
Here is an example of a translation fine-tuning with T5: Here is an example of a translation fine-tuning with T5:
```bash ```bash
python examples/seq2seq/run_seq2seq.py \ python examples/seq2seq/run_translation.py \
--model_name_or_path t5-small \ --model_name_or_path t5-small \
--do_train \ --do_train \
--do_eval \ --do_eval \
--task translation_en_to_ro \ --source_lang en \
--target_lang ro \
--dataset_name wmt16 \ --dataset_name wmt16 \
--dataset_config_name ro-en \ --dataset_config_name ro-en \
--source_prefix "translate English to Romanian: " \
--output_dir /tmp/tst-translation \ --output_dir /tmp/tst-translation \
--per_device_train_batch_size=4 \ --per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \ --per_device_eval_batch_size=4 \
...@@ -160,11 +158,10 @@ python examples/seq2seq/run_seq2seq.py \ ...@@ -160,11 +158,10 @@ python examples/seq2seq/run_seq2seq.py \
And the same with MBart: And the same with MBart:
```bash ```bash
python examples/seq2seq/run_seq2seq.py \ python examples/seq2seq/run_translation.py \
--model_name_or_path facebook/mbart-large-en-ro \ --model_name_or_path facebook/mbart-large-en-ro \
--do_train \ --do_train \
--do_eval \ --do_eval \
--task translation_en_to_ro \
--dataset_name wmt16 \ --dataset_name wmt16 \
--dataset_config_name ro-en \ --dataset_config_name ro-en \
--source_lang en_XX \ --source_lang en_XX \
...@@ -180,18 +177,8 @@ python examples/seq2seq/run_seq2seq.py \ ...@@ -180,18 +177,8 @@ python examples/seq2seq/run_seq2seq.py \
Note, that depending on the used model additional language-specific command-line arguments are sometimes required. Specifically: Note, that depending on the used model additional language-specific command-line arguments are sometimes required. Specifically:
* MBart models require: * MBart models require different `--{source,target}_lang` values, e.g. in place of `en` it expects `en_XX`, for `ro` it expects `ro_RO`. The full MBart specification for language codes can be looked up [here](https://huggingface.co/facebook/mbart-large-cc25)
``` * T5 models can use a `--source_prefix` argument to override the otherwise automated prefix of the form `translate {source_lang} to {target_lang}` for `run_translation.py` and `summarize: ` for `run_summarization.py`
--source_lang en_XX \
--target_lang ro_RO \
```
* T5 requires:
```
--source_prefix "translate English to Romanian: "
```
* yet, other models, require neither.
Also, if you switch to a different language pair, make sure to adjust the source and target values in all command line arguments. Also, if you switch to a different language pair, make sure to adjust the source and target values in all command line arguments.
...@@ -199,14 +186,14 @@ And here is how you would use the translation finetuning on your own files, afte ...@@ -199,14 +186,14 @@ And here is how you would use the translation finetuning on your own files, afte
values for the arguments `--train_file`, `--validation_file` to match your setup: values for the arguments `--train_file`, `--validation_file` to match your setup:
```bash ```bash
python examples/seq2seq/run_seq2seq.py \ python examples/seq2seq/run_translation.py \
--model_name_or_path t5-small \ --model_name_or_path t5-small \
--do_train \ --do_train \
--do_eval \ --do_eval \
--task translation_en_to_ro \ --source_lang en \
--target_lang ro \
--dataset_name wmt16 \ --dataset_name wmt16 \
--dataset_config_name ro-en \ --dataset_config_name ro-en \
--source_prefix "translate English to Romanian: " \
--train_file path_to_jsonlines_file \ --train_file path_to_jsonlines_file \
--validation_file path_to_jsonlines_file \ --validation_file path_to_jsonlines_file \
--output_dir /tmp/tst-translation \ --output_dir /tmp/tst-translation \
...@@ -229,13 +216,13 @@ Here the languages are Romanian (`ro`) and English (`en`). ...@@ -229,13 +216,13 @@ Here the languages are Romanian (`ro`) and English (`en`).
If you want to use a pre-processed dataset that leads to high bleu scores, but for the `en-de` language pair, you can use `--dataset_name wmt14-en-de-pre-processed`, as following: If you want to use a pre-processed dataset that leads to high bleu scores, but for the `en-de` language pair, you can use `--dataset_name wmt14-en-de-pre-processed`, as following:
```bash ```bash
python examples/seq2seq/run_seq2seq.py \ python examples/seq2seq/run_translation.py \
--model_name_or_path t5-small \ --model_name_or_path t5-small \
--do_train \ --do_train \
--do_eval \ --do_eval \
--task translation_en_to_de \ --source_lang en \
--target_lang de \
--dataset_name wmt14-en-de-pre-processed \ --dataset_name wmt14-en-de-pre-processed \
--source_prefix "translate English to German: " \
--output_dir /tmp/tst-translation \ --output_dir /tmp/tst-translation \
--per_device_train_batch_size=4 \ --per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \ --per_device_eval_batch_size=4 \
......
#!/usr/bin/env python #!/usr/bin/env python
# coding=utf-8 # coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. # Copyright 2021 The HuggingFace Team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -20,7 +20,6 @@ Fine-tuning the library models for sequence to sequence. ...@@ -20,7 +20,6 @@ Fine-tuning the library models for sequence to sequence.
import logging import logging
import os import os
import re
import sys import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional
...@@ -37,8 +36,6 @@ from transformers import ( ...@@ -37,8 +36,6 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
DataCollatorForSeq2Seq, DataCollatorForSeq2Seq,
HfArgumentParser, HfArgumentParser,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer, Seq2SeqTrainer,
Seq2SeqTrainingArguments, Seq2SeqTrainingArguments,
default_data_collator, default_data_collator,
...@@ -103,13 +100,6 @@ class DataTrainingArguments: ...@@ -103,13 +100,6 @@ class DataTrainingArguments:
Arguments pertaining to what data we are going to input our model for training and eval. Arguments pertaining to what data we are going to input our model for training and eval.
""" """
task: str = field(
default="summarization",
metadata={
"help": "The name of the task, should be summarization (or summarization_{dataset} for evaluating "
"pegasus) or translation (or translation_{xx}_to_{yy})."
},
)
dataset_name: Optional[str] = field( dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
) )
...@@ -130,15 +120,14 @@ class DataTrainingArguments: ...@@ -130,15 +120,14 @@ class DataTrainingArguments:
validation_file: Optional[str] = field( validation_file: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
"help": "An optional input evaluation data file to evaluate the metrics (rouge/sacreblue) on " "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
"(a jsonlines or csv file)." "(a jsonlines or csv file)."
}, },
) )
test_file: Optional[str] = field( test_file: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
"help": "An optional input test data file to evaluate the metrics (rouge/sacreblue) on " "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
"(a jsonlines or csv file)."
}, },
) )
overwrite_cache: bool = field( overwrite_cache: bool = field(
...@@ -200,8 +189,6 @@ class DataTrainingArguments: ...@@ -200,8 +189,6 @@ class DataTrainingArguments:
"value if set." "value if set."
}, },
) )
source_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
target_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
num_beams: Optional[int] = field( num_beams: Optional[int] = field(
default=None, default=None,
metadata={ metadata={
...@@ -229,10 +216,6 @@ class DataTrainingArguments: ...@@ -229,10 +216,6 @@ class DataTrainingArguments:
if self.validation_file is not None: if self.validation_file is not None:
extension = self.validation_file.split(".")[-1] extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if not self.task.startswith("summarization") and not self.task.startswith("translation"):
raise ValueError(
"`task` should be summarization, summarization_{dataset}, translation or translation_{xx}_to_{yy}."
)
if self.val_max_target_length is None: if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length self.val_max_target_length = self.max_target_length
...@@ -265,6 +248,18 @@ def main(): ...@@ -265,6 +248,18 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if data_args.source_prefix is None and model_args.model_name_or_path in [
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
]:
logger.warning(
"You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
"`--source_prefix 'summarize: ' `"
)
# Detecting last checkpoint. # Detecting last checkpoint.
last_checkpoint = None last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
...@@ -305,11 +300,8 @@ def main(): ...@@ -305,11 +300,8 @@ def main():
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub). # (the dataset will be downloaded automatically from the datasets Hub).
# #
# For CSV/JSON files in the summarization task, this script will use the first column for the full texts and the # For CSV/JSON files this script will use the first column for the full texts and the second column for the
# second column for the summaries (unless you specify column names for this with the `text_column` and # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
# `summary_column` arguments).
# For translation, only JSON files are supported, with one field named "translation" containing two keys for the
# source and target languages (unless you adapt what follows).
# #
# In distributed training, the load_dataset function guarantee that only one local process can concurrently # In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset. # download the dataset.
...@@ -358,16 +350,6 @@ def main(): ...@@ -358,16 +350,6 @@ def main():
use_auth_token=True if model_args.use_auth_token else None, use_auth_token=True if model_args.use_auth_token else None,
) )
# Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
assert (
data_args.target_lang is not None and data_args.source_lang is not None
), "mBart requires --target_lang and --source_lang"
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
if model.config.decoder_start_token_id is None: if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
...@@ -385,55 +367,24 @@ def main(): ...@@ -385,55 +367,24 @@ def main():
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
return return
# For translation we set the codes of our source and target languages (only useful for mBART, the others will # Get the column names for input/target.
# ignore those attributes). dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
if data_args.task.startswith("translation") or isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): if data_args.text_column is None:
if data_args.source_lang is not None: text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
tokenizer.src_lang = data_args.source_lang
if data_args.target_lang is not None:
tokenizer.tgt_lang = data_args.target_lang
# To serialize preprocess_function below, each of those four variables needs to be defined (even if we won't use
# them all).
source_lang, target_lang, text_column, summary_column = None, None, None, None
if data_args.task.startswith("summarization"):
# Get the column names for input/target.
dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
if data_args.text_column is None:
text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
text_column = data_args.text_column
if text_column not in column_names:
raise ValueError(
f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
)
if data_args.summary_column is None:
summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
summary_column = data_args.summary_column
if summary_column not in column_names:
raise ValueError(
f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
)
else: else:
# Get the language codes for input/target. text_column = data_args.text_column
lang_search = re.match("translation_([a-z]+)_to_([a-z]+)", data_args.task) if text_column not in column_names:
if data_args.source_lang is not None: raise ValueError(
source_lang = data_args.source_lang.split("_")[0] f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
else: )
assert ( if data_args.summary_column is None:
lang_search is not None summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
), "Provide a source language via --source_lang or rename your task 'translation_xx_to_yy'." else:
source_lang = lang_search.groups()[0] summary_column = data_args.summary_column
if summary_column not in column_names:
if data_args.target_lang is not None: raise ValueError(
target_lang = data_args.target_lang.split("_")[0] f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
else: )
assert (
lang_search is not None
), "Provide a target language via --target_lang or rename your task 'translation_xx_to_yy'."
target_lang = lang_search.groups()[1]
# Temporarily set max_target_length for training. # Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length max_target_length = data_args.max_target_length
...@@ -446,12 +397,8 @@ def main(): ...@@ -446,12 +397,8 @@ def main():
) )
def preprocess_function(examples): def preprocess_function(examples):
if data_args.task.startswith("translation"): inputs = examples[text_column]
inputs = [ex[source_lang] for ex in examples["translation"]] targets = examples[summary_column]
targets = [ex[target_lang] for ex in examples["translation"]]
else:
inputs = examples[text_column]
targets = examples[summary_column]
inputs = [prefix + inp for inp in inputs] inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
...@@ -526,19 +473,15 @@ def main(): ...@@ -526,19 +473,15 @@ def main():
) )
# Metric # Metric
metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu" metric = load_metric("rouge")
metric = load_metric(metric_name)
def postprocess_text(preds, labels): def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds] preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels] labels = [label.strip() for label in labels]
# rougeLSum expects newline after each sentence # rougeLSum expects newline after each sentence
if metric_name == "rouge": preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
else: # sacrebleu
labels = [[label] for label in labels]
return preds, labels return preds, labels
...@@ -555,13 +498,9 @@ def main(): ...@@ -555,13 +498,9 @@ def main():
# Some simple post-processing # Some simple post-processing
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
if metric_name == "rouge": result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) # Extract a few results from ROUGE
# Extract a few results from ROUGE result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
else:
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
result = {"bleu": result["score"]}
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens) result["gen_len"] = np.mean(prediction_lens)
...@@ -601,6 +540,7 @@ def main(): ...@@ -601,6 +540,7 @@ def main():
trainer.save_state() trainer.save_state()
# Evaluation # Evaluation
results = {}
if training_args.do_eval: if training_args.do_eval:
logger.info("*** Evaluate ***") logger.info("*** Evaluate ***")
...@@ -613,7 +553,6 @@ def main(): ...@@ -613,7 +553,6 @@ def main():
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
# predict
if training_args.do_predict: if training_args.do_predict:
logger.info("*** Test ***") logger.info("*** Test ***")
...@@ -640,6 +579,8 @@ def main(): ...@@ -640,6 +579,8 @@ def main():
with open(output_test_preds_file, "w") as writer: with open(output_test_preds_file, "w") as writer:
writer.write("\n".join(test_preds)) writer.write("\n".join(test_preds))
return results
def _mp_fn(index): def _mp_fn(index):
# For xla_spawn (TPUs) # For xla_spawn (TPUs)
......
This diff is collapsed.
...@@ -49,8 +49,9 @@ if SRC_DIRS is not None: ...@@ -49,8 +49,9 @@ if SRC_DIRS is not None:
import run_mlm import run_mlm
import run_ner import run_ner
import run_qa as run_squad import run_qa as run_squad
import run_seq2seq import run_summarization
import run_swag import run_swag
import run_translation
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
...@@ -277,15 +278,14 @@ class ExamplesTests(TestCasePlus): ...@@ -277,15 +278,14 @@ class ExamplesTests(TestCasePlus):
self.assertGreaterEqual(len(result[0]), 10) self.assertGreaterEqual(len(result[0]), 10)
@slow @slow
def test_run_seq2seq_summarization(self): def test_run_summarization(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir() tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f""" testargs = f"""
run_seq2seq.py run_summarization.py
--model_name_or_path t5-small --model_name_or_path t5-small
--task summarization
--train_file tests/fixtures/tests_samples/xsum/sample.json --train_file tests/fixtures/tests_samples/xsum/sample.json
--validation_file tests/fixtures/tests_samples/xsum/sample.json --validation_file tests/fixtures/tests_samples/xsum/sample.json
--output_dir {tmp_dir} --output_dir {tmp_dir}
...@@ -301,7 +301,7 @@ class ExamplesTests(TestCasePlus): ...@@ -301,7 +301,7 @@ class ExamplesTests(TestCasePlus):
""".split() """.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
run_seq2seq.main() run_summarization.main()
result = get_results(tmp_dir) result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_rouge1"], 10) self.assertGreaterEqual(result["eval_rouge1"], 10)
self.assertGreaterEqual(result["eval_rouge2"], 2) self.assertGreaterEqual(result["eval_rouge2"], 2)
...@@ -309,15 +309,16 @@ class ExamplesTests(TestCasePlus): ...@@ -309,15 +309,16 @@ class ExamplesTests(TestCasePlus):
self.assertGreaterEqual(result["eval_rougeLsum"], 7) self.assertGreaterEqual(result["eval_rougeLsum"], 7)
@slow @slow
def test_run_seq2seq_translation(self): def test_run_translation(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir() tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f""" testargs = f"""
run_seq2seq.py run_translation.py
--model_name_or_path sshleifer/student_marian_en_ro_6_1 --model_name_or_path sshleifer/student_marian_en_ro_6_1
--task translation_en_to_ro --source_lang en
--target_lang ro
--train_file tests/fixtures/tests_samples/wmt16/sample.json --train_file tests/fixtures/tests_samples/wmt16/sample.json
--validation_file tests/fixtures/tests_samples/wmt16/sample.json --validation_file tests/fixtures/tests_samples/wmt16/sample.json
--output_dir {tmp_dir} --output_dir {tmp_dir}
...@@ -335,6 +336,6 @@ class ExamplesTests(TestCasePlus): ...@@ -335,6 +336,6 @@ class ExamplesTests(TestCasePlus):
""".split() """.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
run_seq2seq.main() run_translation.main()
result = get_results(tmp_dir) result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_bleu"], 30) self.assertGreaterEqual(result["eval_bleu"], 30)
...@@ -233,7 +233,6 @@ class TestDeepSpeed(TestCasePlus): ...@@ -233,7 +233,6 @@ class TestDeepSpeed(TestCasePlus):
--group_by_length --group_by_length
--label_smoothing_factor 0.1 --label_smoothing_factor 0.1
--adafactor --adafactor
--task translation
--target_lang ro_RO --target_lang ro_RO
--source_lang en_XX --source_lang en_XX
""".split() """.split()
...@@ -246,7 +245,7 @@ class TestDeepSpeed(TestCasePlus): ...@@ -246,7 +245,7 @@ class TestDeepSpeed(TestCasePlus):
args = [x for x in args if x not in remove_args] args = [x for x in args if x not in remove_args]
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split() ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
script = [f"{self.examples_dir_str}/seq2seq/run_seq2seq.py"] script = [f"{self.examples_dir_str}/seq2seq/run_translation.py"]
num_gpus = get_gpu_count() if distributed else 1 num_gpus = get_gpu_count() if distributed else 1
launcher = f"deepspeed --num_gpus {num_gpus}".split() launcher = f"deepspeed --num_gpus {num_gpus}".split()
......
...@@ -35,7 +35,7 @@ from transformers.trainer_utils import set_seed ...@@ -35,7 +35,7 @@ from transformers.trainer_utils import set_seed
bindir = os.path.abspath(os.path.dirname(__file__)) bindir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(f"{bindir}/../../seq2seq") sys.path.append(f"{bindir}/../../seq2seq")
from run_seq2seq import main # noqa from run_translation import main # noqa
set_seed(42) set_seed(42)
...@@ -209,7 +209,6 @@ class TestTrainerExt(TestCasePlus): ...@@ -209,7 +209,6 @@ class TestTrainerExt(TestCasePlus):
--group_by_length --group_by_length
--label_smoothing_factor 0.1 --label_smoothing_factor 0.1
--adafactor --adafactor
--task translation
--target_lang ro_RO --target_lang ro_RO
--source_lang en_XX --source_lang en_XX
""" """
...@@ -226,12 +225,12 @@ class TestTrainerExt(TestCasePlus): ...@@ -226,12 +225,12 @@ class TestTrainerExt(TestCasePlus):
distributed_args = f""" distributed_args = f"""
-m torch.distributed.launch -m torch.distributed.launch
--nproc_per_node={n_gpu} --nproc_per_node={n_gpu}
{self.examples_dir_str}/seq2seq/run_seq2seq.py {self.examples_dir_str}/seq2seq/run_translation.py
""".split() """.split()
cmd = [sys.executable] + distributed_args + args cmd = [sys.executable] + distributed_args + args
execute_subprocess_async(cmd, env=self.get_env()) execute_subprocess_async(cmd, env=self.get_env())
else: else:
testargs = ["run_seq2seq.py"] + args testargs = ["run_translation.py"] + args
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
main() main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment