Unverified Commit 6f840990 authored by Théo Matussière, committed by GitHub

split seq2seq script into summarization & translation (#10611)



* split seq2seq script, update docs

* needless diff

* fix readme

* remove test diff

* s/summarization/translation
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* cr

* fix arguments & better mbart/t5 refs

* copyright
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* reword readme
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* s/summarization/translation

* short script names

* fix tests

* fix isort, include mbart doc

* delete old script, update tests

* automate source prefix

* automate source prefix for translation

* s/translation/trans
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* fix script name (short version)

* typos
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* exact parameter
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* remove superfluous source_prefix calls in docs

* rename scripts & warn for source prefix

* black

* flake8
Co-authored-by: theo <theo@matussie.re>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent 505494a8
......@@ -168,13 +168,13 @@ Here is an example of how this can be used on a filesystem that is shared betwee
On the instance with normal network access, run your program, which will download and cache models (and optionally datasets if you use 🤗 Datasets). For example:
```
python examples/seq2seq/run_seq2seq.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ...
python examples/seq2seq/run_translation.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ...
```
and then, with the same filesystem, you can run the same program on a firewalled instance:
```
HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \
python examples/seq2seq/run_seq2seq.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ...
python examples/seq2seq/run_translation.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ...
```
and it should succeed without hanging while waiting to time out.
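The same switches can also be set from inside a Python program before the libraries are imported, which may be more convenient in notebooks; a minimal sketch, assuming the cache was already populated as above:
```python
import os

# Set the offline switches before importing transformers/datasets so that only
# locally cached files are used (assumes the cache was filled beforehand).
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"

from transformers import AutoTokenizer

# Loads from the shared cache without reaching out to the network.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
```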
......
......@@ -279,16 +279,16 @@ To deploy this feature:
and make sure you have added the distributed launcher ``-m torch.distributed.launch
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
For example, here is how you could use it for ``run_translation.py`` with 2 GPUs:
.. code-block:: bash
python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_seq2seq.py \
python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_translation.py \
--model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir \
--do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \
--task translation_en_to_ro --source_prefix "translate English to Romanian: " \
--source_lang en --target_lang ro \
--fp16 --sharded_ddp simple
Notes:
......@@ -304,16 +304,16 @@ Notes:
to the command line arguments, and make sure you have added the distributed launcher ``-m torch.distributed.launch
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
For example, here is how you could use it for ``run_translation.py`` with 2 GPUs:
.. code-block:: bash
python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_seq2seq.py \
python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_translation.py \
--model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir \
--do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \
--task translation_en_to_ro --source_prefix "translate English to Romanian: " \
--source_lang en --target_lang ro \
--fp16 --sharded_ddp zero_dp_2
:obj:`zero_dp_2` is an optimized version of the simple wrapper, while :obj:`zero_dp_3` fully shards model weights,
......@@ -333,7 +333,7 @@ Notes:
Known caveats:
- This feature is incompatible with :obj:`--predict_with_generate` in the `run_seq2seq.py` script.
- This feature is incompatible with :obj:`--predict_with_generate` in the `run_translation.py` script.
- Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container
:obj:`FullyShardedDataParallelism` of fairscale. It should be used with the option :obj:`auto_wrap` if you are not
doing this yourself: :obj:`--sharded_ddp "zero_dp_3 auto_wrap"`.
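For reference, here is a hedged sketch of passing the same option when building the training arguments in Python rather than on the command line (it assumes the ``sharded_ddp`` field accepts the same space-separated values as the CLI flag):

.. code-block:: python

    from transformers import Seq2SeqTrainingArguments

    # Assumed equivalent of `--sharded_ddp "zero_dp_3 auto_wrap"` on the command line;
    # the caveat above about `--predict_with_generate` still applies.
    training_args = Seq2SeqTrainingArguments(
        output_dir="output_dir",
        fp16=True,
        sharded_ddp="zero_dp_3 auto_wrap",
    )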
......@@ -402,17 +402,17 @@ In fact, you can continue using ``-m torch.distributed.launch`` with DeepSpeed a
the ``deepspeed`` launcher. But since in the DeepSpeed documentation it'll be used everywhere, for consistency we will
use it here as well.
Here is an example of running ``run_seq2seq.py`` under DeepSpeed deploying all available GPUs:
Here is an example of running ``run_translation.py`` under DeepSpeed, deploying all available GPUs:
.. code-block:: bash
deepspeed examples/seq2seq/run_seq2seq.py \
deepspeed examples/seq2seq/run_translation.py \
--deepspeed examples/tests/deepspeed/ds_config.json \
--model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir --fp16 \
--do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \
--task translation_en_to_ro --source_prefix "translate English to Romanian: "
--source_lang en --target_lang ro
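If you drive the :class:`~transformers.Trainer` from your own code instead of the example script, the same config file can presumably be passed through the ``deepspeed`` training argument; a minimal sketch under that assumption:

.. code-block:: python

    from transformers import Seq2SeqTrainingArguments

    # Assumed equivalent of the `--deepspeed examples/tests/deepspeed/ds_config.json` flag above.
    training_args = Seq2SeqTrainingArguments(
        output_dir="output_dir",
        fp16=True,
        deepspeed="examples/tests/deepspeed/ds_config.json",
    )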
Note that in the DeepSpeed documentation you are likely to see ``--deepspeed --deepspeed_config ds_config.json`` - i.e.
......@@ -431,13 +431,13 @@ To deploy DeepSpeed with one GPU adjust the :class:`~transformers.Trainer` comma
.. code-block:: bash
deepspeed --num_gpus=1 examples/seq2seq/run_seq2seq.py \
deepspeed --num_gpus=1 examples/seq2seq/run_translation.py \
--deepspeed examples/tests/deepspeed/ds_config.json \
--model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir --fp16 \
--do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \
--task translation_en_to_ro --source_prefix "translate English to Romanian: "
--source_lang en --target_lang ro
This is almost the same as with multiple GPUs, but here we tell DeepSpeed explicitly to use just one GPU. By default,
DeepSpeed deploys all GPUs it can see. If you have only 1 GPU to start with, then you don't need this argument. The
......@@ -483,7 +483,7 @@ Notes:
.. code-block:: bash
deepspeed --include localhost:1 examples/seq2seq/run_seq2seq.py ...
deepspeed --include localhost:1 examples/seq2seq/run_translation.py ...
In this example, we tell DeepSpeed to use GPU 1 (the second GPU).
......@@ -574,7 +574,7 @@ with:
.. code-block::
!deepspeed examples/seq2seq/run_seq2seq.py ...
!deepspeed examples/seq2seq/run_translation.py ...
or with bash magic, where you can write a multi-line code for the shell to run:
......@@ -583,7 +583,7 @@ or with bash magic, where you can write a multi-line code for the shell to run:
%%bash
cd /somewhere
deepspeed examples/seq2seq/run_seq2seq.py ...
deepspeed examples/seq2seq/run_translation.py ...
......
......@@ -742,8 +742,8 @@ Summarization
-----------------------------------------------------------------------------------------------------------------------
Summarization is the task of summarizing a document or an article into a shorter text. If you would like to fine-tune a
model on a summarization task, you may leverage the `run_seq2seq.py
<https://github.com/huggingface/transformers/tree/master/examples/seq2seq/run_seq2seq.py>`__ script.
model on a summarization task, you may leverage the `run_summarization.py
<https://github.com/huggingface/transformers/tree/master/examples/seq2seq/run_summarization.py>`__ script.
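If you only want to get a feel for the task before fine-tuning, a summarization pipeline is the quickest way; a minimal sketch (the checkpoint is just an example):

.. code-block:: python

    from transformers import pipeline

    # Illustration only: any summarization-capable checkpoint works here.
    summarizer = pipeline("summarization", model="t5-small")
    article = "A very long news article about recent events ..."
    print(summarizer(article, max_length=60, min_length=10))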
An example of a summarization dataset is the CNN / Daily Mail dataset, which consists of long news articles and was
created for the task of summarization. If you would like to fine-tune a model on a summarization task, various
......@@ -822,8 +822,8 @@ Translation
-----------------------------------------------------------------------------------------------------------------------
Translation is the task of translating a text from one language to another. If you would like to fine-tune a model on a
translation task, you may leverage the `run_seq2seq.py
<https://github.com/huggingface/transformers/tree/master/examples/seq2seq/run_seq2seq.py>`__ script.
translation task, you may leverage the `run_translation.py
<https://github.com/huggingface/transformers/tree/master/examples/seq2seq/run_translation.py>`__ script.
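As with summarization, a translation pipeline gives a quick preview of the task before any fine-tuning; a minimal sketch (checkpoint chosen only for illustration):

.. code-block:: python

    from transformers import pipeline

    # Illustration only: the task name selects English-to-German translation.
    translator = pipeline("translation_en_to_de", model="t5-small")
    print(translator("How old are you?"))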
An example of a translation dataset is the WMT English to German dataset, which has sentences in English as the input
data and the corresponding sentences in German as the target data. If you would like to fine-tune a model on a
......
......@@ -30,7 +30,7 @@ For the old `finetune_trainer.py` and related utils, see [`examples/legacy/seq2s
- `FSMTForConditionalGeneration` (translation only)
- `T5ForConditionalGeneration`
`run_seq2seq.py` is a lightweight example of how to download and preprocess a dataset from the [🤗 Datasets](https://github.com/huggingface/datasets) library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.
`run_summarization.py` and `run_translation.py` are lightweight examples of how to download and preprocess a dataset from the [🤗 Datasets](https://github.com/huggingface/datasets) library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.
For custom datasets in `jsonlines` format, please see https://huggingface.co/docs/datasets/loading_datasets.html#json-files.
You will also find examples of these below.
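As a quick reference before those examples, here is a hedged sketch of what a single training record in such a file could look like for each script (the summarization column names are whatever you later pass via `--text_column`/`--summary_column`):
```python
import json

# Assumed minimal records, one JSON object per line.
summarization_record = {"text": "A long document to be condensed ...", "summary": "A short summary."}
translation_record = {"translation": {"en": "Hello.", "ro": "Bună."}}

with open("train_summarization.json", "w") as f:
    f.write(json.dumps(summarization_record, ensure_ascii=False) + "\n")

with open("train_translation.json", "w") as f:
    f.write(json.dumps(translation_record, ensure_ascii=False) + "\n")
```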
......@@ -39,11 +39,10 @@ and you also will find examples of these below.
Here is an example on a summarization task:
```bash
python examples/seq2seq/run_seq2seq.py \
python examples/seq2seq/run_summarization.py \
--model_name_or_path t5-small \
--do_train \
--do_eval \
--task summarization \
--dataset_name xsum \
--output_dir /tmp/tst-summarization \
--per_device_train_batch_size=4 \
......@@ -60,11 +59,10 @@ And here is how you would use it on your own files, after adjusting the values f
`--train_file`, `--validation_file`, `--text_column` and `--summary_column` to match your setup:
```bash
python examples/seq2seq/run_seq2seq.py \
python examples/seq2seq/run_summarization.py \
--model_name_or_path t5-small \
--do_train \
--do_eval \
--task summarization \
--train_file path_to_csv_or_jsonlines_file \
--validation_file path_to_csv_or_jsonlines_file \
--output_dir /tmp/tst-summarization \
......@@ -140,14 +138,14 @@ And as with the CSV files, you can specify which values to select from the file,
Here is an example of translation fine-tuning with T5:
```bash
python examples/seq2seq/run_seq2seq.py \
python examples/seq2seq/run_translation.py \
--model_name_or_path t5-small \
--do_train \
--do_eval \
--task translation_en_to_ro \
--source_lang en \
--target_lang ro \
--dataset_name wmt16 \
--dataset_config_name ro-en \
--source_prefix "translate English to Romanian: " \
--output_dir /tmp/tst-translation \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
......@@ -160,11 +158,10 @@ python examples/seq2seq/run_seq2seq.py \
And the same with MBart:
```bash
python examples/seq2seq/run_seq2seq.py \
python examples/seq2seq/run_translation.py \
--model_name_or_path facebook/mbart-large-en-ro \
--do_train \
--do_eval \
--task translation_en_to_ro \
--dataset_name wmt16 \
--dataset_config_name ro-en \
--source_lang en_XX \
......@@ -180,18 +177,8 @@ python examples/seq2seq/run_seq2seq.py \
Note that, depending on the model used, additional language-specific command-line arguments are sometimes required. Specifically:
* MBart models require:
```
--source_lang en_XX \
--target_lang ro_RO \
```
* T5 requires:
```
--source_prefix "translate English to Romanian: "
```
* yet, other models, require neither.
* MBart models require different `--{source,target}_lang` values: e.g. in place of `en` they expect `en_XX`, and for `ro` they expect `ro_RO`. The full MBart specification for language codes can be found [here](https://huggingface.co/facebook/mbart-large-cc25).
* T5 models can use a `--source_prefix` argument to override the otherwise automated prefix of the form `translate {source_lang} to {target_lang}` for `run_translation.py` and `summarize: ` for `run_summarization.py`.
Also, if you switch to a different language pair, make sure to adjust the source and target values in all command line arguments.
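To make the notes above concrete, here is an illustrative (not exhaustive) summary of the extra arguments for the `en`-`ro` pair, mirroring the examples in this README:
```python
# Illustrative only: extra command-line arguments each model family expects for en->ro,
# per the notes above (MBart wants locale-style codes, T5 takes an optional prefix).
extra_args = {
    "facebook/mbart-large-en-ro": ["--source_lang", "en_XX", "--target_lang", "ro_RO"],
    "t5-small": [
        "--source_lang", "en",
        "--target_lang", "ro",
        "--source_prefix", "translate English to Romanian: ",
    ],
}

for model, args in extra_args.items():
    print(model, " ".join(args))
```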
......@@ -199,14 +186,14 @@ And here is how you would use the translation finetuning on your own files, afte
values for the arguments `--train_file`, `--validation_file` to match your setup:
```bash
python examples/seq2seq/run_seq2seq.py \
python examples/seq2seq/run_translation.py \
--model_name_or_path t5-small \
--do_train \
--do_eval \
--task translation_en_to_ro \
--source_lang en \
--target_lang ro \
--dataset_name wmt16 \
--dataset_config_name ro-en \
--source_prefix "translate English to Romanian: " \
--train_file path_to_jsonlines_file \
--validation_file path_to_jsonlines_file \
--output_dir /tmp/tst-translation \
......@@ -229,13 +216,13 @@ Here the languages are Romanian (`ro`) and English (`en`).
If you want to use a pre-processed dataset that leads to high BLEU scores, but for the `en-de` language pair, you can use `--dataset_name wmt14-en-de-pre-processed`, as follows:
```bash
python examples/seq2seq/run_seq2seq.py \
python examples/seq2seq/run_translation.py \
--model_name_or_path t5-small \
--do_train \
--do_eval \
--task translation_en_to_de \
--source_lang en \
--target_lang de \
--dataset_name wmt14-en-de-pre-processed \
--source_prefix "translate English to German: " \
--output_dir /tmp/tst-translation \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
......
#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,7 +20,6 @@ Fine-tuning the library models for sequence to sequence.
import logging
import os
import re
import sys
from dataclasses import dataclass, field
from typing import Optional
......@@ -37,8 +36,6 @@ from transformers import (
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
......@@ -103,13 +100,6 @@ class DataTrainingArguments:
Arguments pertaining to what data we are going to input to our model for training and eval.
"""
task: str = field(
default="summarization",
metadata={
"help": "The name of the task, should be summarization (or summarization_{dataset} for evaluating "
"pegasus) or translation (or translation_{xx}_to_{yy})."
},
)
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
......@@ -130,15 +120,14 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input evaluation data file to evaluate the metrics (rouge/sacreblue) on "
"help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
"(a jsonlines or csv file)."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (rouge/sacreblue) on "
"(a jsonlines or csv file)."
"help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
......@@ -200,8 +189,6 @@ class DataTrainingArguments:
"value if set."
},
)
source_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
target_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
num_beams: Optional[int] = field(
default=None,
metadata={
......@@ -229,10 +216,6 @@ class DataTrainingArguments:
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if not self.task.startswith("summarization") and not self.task.startswith("translation"):
raise ValueError(
"`task` should be summarization, summarization_{dataset}, translation or translation_{xx}_to_{yy}."
)
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length
......@@ -265,6 +248,18 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if data_args.source_prefix is None and model_args.model_name_or_path in [
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
]:
logger.warning(
"You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
"`--source_prefix 'summarize: ' `"
)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
......@@ -305,11 +300,8 @@ def main():
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files in the summarization task, this script will use the first column for the full texts and the
# second column for the summaries (unless you specify column names for this with the `text_column` and
# `summary_column` arguments).
# For translation, only JSON files are supported, with one field named "translation" containing two keys for the
# source and target languages (unless you adapt what follows).
# For CSV/JSON files this script will use the first column for the full texts and the second column for the
# summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
#
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
......@@ -358,16 +350,6 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
)
# Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
assert (
data_args.target_lang is not None and data_args.source_lang is not None
), "mBart requires --target_lang and --source_lang"
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
......@@ -385,19 +367,6 @@ def main():
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
return
# For translation we set the codes of our source and target languages (only useful for mBART, the others will
# ignore those attributes).
if data_args.task.startswith("translation") or isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if data_args.source_lang is not None:
tokenizer.src_lang = data_args.source_lang
if data_args.target_lang is not None:
tokenizer.tgt_lang = data_args.target_lang
# To serialize preprocess_function below, each of those four variables needs to be defined (even if we won't use
# them all).
source_lang, target_lang, text_column, summary_column = None, None, None, None
if data_args.task.startswith("summarization"):
# Get the column names for input/target.
dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
if data_args.text_column is None:
......@@ -416,24 +385,6 @@ def main():
raise ValueError(
f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
)
else:
# Get the language codes for input/target.
lang_search = re.match("translation_([a-z]+)_to_([a-z]+)", data_args.task)
if data_args.source_lang is not None:
source_lang = data_args.source_lang.split("_")[0]
else:
assert (
lang_search is not None
), "Provide a source language via --source_lang or rename your task 'translation_xx_to_yy'."
source_lang = lang_search.groups()[0]
if data_args.target_lang is not None:
target_lang = data_args.target_lang.split("_")[0]
else:
assert (
lang_search is not None
), "Provide a target language via --target_lang or rename your task 'translation_xx_to_yy'."
target_lang = lang_search.groups()[1]
# Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length
......@@ -446,10 +397,6 @@ def main():
)
def preprocess_function(examples):
if data_args.task.startswith("translation"):
inputs = [ex[source_lang] for ex in examples["translation"]]
targets = [ex[target_lang] for ex in examples["translation"]]
else:
inputs = examples[text_column]
targets = examples[summary_column]
inputs = [prefix + inp for inp in inputs]
......@@ -526,19 +473,15 @@ def main():
)
# Metric
metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu"
metric = load_metric(metric_name)
metric = load_metric("rouge")
def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]
# rougeLSum expects newline after each sentence
if metric_name == "rouge":
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
else: # sacrebleu
labels = [[label] for label in labels]
return preds, labels
......@@ -555,13 +498,9 @@ def main():
# Some simple post-processing
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
if metric_name == "rouge":
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
else:
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
result = {"bleu": result["score"]}
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens)
......@@ -601,6 +540,7 @@ def main():
trainer.save_state()
# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
......@@ -613,7 +553,6 @@ def main():
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# predict
if training_args.do_predict:
logger.info("*** Test ***")
......@@ -640,6 +579,8 @@ def main():
with open(output_test_preds_file, "w") as writer:
writer.write("\n".join(test_preds))
return results
def _mp_fn(index):
# For xla_spawn (TPUs)
......
This diff is collapsed.
......@@ -49,8 +49,9 @@ if SRC_DIRS is not None:
import run_mlm
import run_ner
import run_qa as run_squad
import run_seq2seq
import run_summarization
import run_swag
import run_translation
logging.basicConfig(level=logging.DEBUG)
......@@ -277,15 +278,14 @@ class ExamplesTests(TestCasePlus):
self.assertGreaterEqual(len(result[0]), 10)
@slow
def test_run_seq2seq_summarization(self):
def test_run_summarization(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_seq2seq.py
run_summarization.py
--model_name_or_path t5-small
--task summarization
--train_file tests/fixtures/tests_samples/xsum/sample.json
--validation_file tests/fixtures/tests_samples/xsum/sample.json
--output_dir {tmp_dir}
......@@ -301,7 +301,7 @@ class ExamplesTests(TestCasePlus):
""".split()
with patch.object(sys, "argv", testargs):
run_seq2seq.main()
run_summarization.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_rouge1"], 10)
self.assertGreaterEqual(result["eval_rouge2"], 2)
......@@ -309,15 +309,16 @@ class ExamplesTests(TestCasePlus):
self.assertGreaterEqual(result["eval_rougeLsum"], 7)
@slow
def test_run_seq2seq_translation(self):
def test_run_translation(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_seq2seq.py
run_translation.py
--model_name_or_path sshleifer/student_marian_en_ro_6_1
--task translation_en_to_ro
--source_lang en
--target_lang ro
--train_file tests/fixtures/tests_samples/wmt16/sample.json
--validation_file tests/fixtures/tests_samples/wmt16/sample.json
--output_dir {tmp_dir}
......@@ -335,6 +336,6 @@ class ExamplesTests(TestCasePlus):
""".split()
with patch.object(sys, "argv", testargs):
run_seq2seq.main()
run_translation.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_bleu"], 30)
......@@ -233,7 +233,6 @@ class TestDeepSpeed(TestCasePlus):
--group_by_length
--label_smoothing_factor 0.1
--adafactor
--task translation
--target_lang ro_RO
--source_lang en_XX
""".split()
......@@ -246,7 +245,7 @@ class TestDeepSpeed(TestCasePlus):
args = [x for x in args if x not in remove_args]
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
script = [f"{self.examples_dir_str}/seq2seq/run_seq2seq.py"]
script = [f"{self.examples_dir_str}/seq2seq/run_translation.py"]
num_gpus = get_gpu_count() if distributed else 1
launcher = f"deepspeed --num_gpus {num_gpus}".split()
......
......@@ -35,7 +35,7 @@ from transformers.trainer_utils import set_seed
bindir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(f"{bindir}/../../seq2seq")
from run_seq2seq import main # noqa
from run_translation import main # noqa
set_seed(42)
......@@ -209,7 +209,6 @@ class TestTrainerExt(TestCasePlus):
--group_by_length
--label_smoothing_factor 0.1
--adafactor
--task translation
--target_lang ro_RO
--source_lang en_XX
"""
......@@ -226,12 +225,12 @@ class TestTrainerExt(TestCasePlus):
distributed_args = f"""
-m torch.distributed.launch
--nproc_per_node={n_gpu}
{self.examples_dir_str}/seq2seq/run_seq2seq.py
{self.examples_dir_str}/seq2seq/run_translation.py
""".split()
cmd = [sys.executable] + distributed_args + args
execute_subprocess_async(cmd, env=self.get_env())
else:
testargs = ["run_seq2seq.py"] + args
testargs = ["run_translation.py"] + args
with patch.object(sys, "argv", testargs):
main()
......