Unverified commit eadd870b, authored by Stas Bekman, committed by GitHub

[seq2seq] make it easier to run the scripts (#7274)

parent 8d3bb781
@@ -100,7 +100,7 @@ All finetuning bash scripts call finetune.py (or distillation.py) with reasonable
To see all the possible command line options, run:
```bash
-./finetune.sh --help # this calls python finetune.py --help
+./finetune.py --help
```

### Finetuning Training Params
@@ -197,7 +197,7 @@ If 'translation' is in your task name, the computed metric will be BLEU. Otherwise
For t5, you need to specify --task translation_{src}_to_{tgt} as follows:
```bash
export DATA_DIR=wmt_en_ro
-python run_eval.py t5-base \
+./run_eval.py t5-base \
    $DATA_DIR/val.source t5_val_generations.txt \
    --reference_path $DATA_DIR/val.target \
    --score_path enro_bleu.json \
@@ -211,7 +211,7 @@ python run_eval.py t5-base \
This command works for MBART, although the BLEU score is suspiciously low.
```bash
export DATA_DIR=wmt_en_ro
-python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
+./run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
    --reference_path $DATA_DIR/val.target \
    --score_path enro_bleu.json \
    --task translation \
@@ -224,7 +224,7 @@ python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_gen
Summarization (xsum will be very similar):
```bash
export DATA_DIR=cnn_dm
-python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
+./run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
    --reference_path $DATA_DIR/val.target \
    --score_path cnn_rouge.json \
    --task summarization \
@@ -238,7 +238,7 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
### Multi-GPU Evaluation
Here is a command to run xsum evaluation on 8 GPUs. It is more than linearly faster than run_eval.py in some cases
because it uses SortishSampler to minimize padding. You can also use it on 1 GPU. `data_dir` must have
-`{type_path}.source` and `{type_path}.target`. Run `python run_distributed_eval.py --help` for all clargs.
+`{type_path}.source` and `{type_path}.target`. Run `./run_distributed_eval.py --help` for all clargs.
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
@@ -371,11 +371,11 @@ This feature can only be used:
- with fairseq installed
- on 1 GPU
- without sortish sampler
-- after calling `python save_len_file.py $tok $data_dir`
+- after calling `./save_len_file.py $tok $data_dir`
For example,
```bash
-python save_len_file.py Helsinki-NLP/opus-mt-en-ro wmt_en_ro
+./save_len_file.py Helsinki-NLP/opus-mt-en-ro wmt_en_ro
./dynamic_bs_example.sh --max_tokens_per_batch=2000 --output_dir benchmark_dynamic_bs
```
This splits `wmt_en_ro/train` into 11,197 uneven-length batches and can finish 1 epoch in 8 minutes on a V100.
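To make the token-budget idea concrete, here is a minimal sketch of the packing logic: sort by length so similar-sized examples share a batch, then cap each batch so its padded size (batch size × longest member) stays under the budget. This is illustrative only, not the repo's SortishSampler or `--max_tokens_per_batch` implementation; all names below are made up.

```python
from typing import List


def pack_batches(lengths: List[int], max_tokens_per_batch: int) -> List[List[int]]:
    """Group example indices into batches whose padded size stays under the budget."""
    # Longest-first so examples of similar length land in the same batch.
    order = sorted(range(len(lengths)), key=lambda i: -lengths[i])
    batches: List[List[int]] = []
    current: List[int] = []
    for idx in order:
        # Padded cost if we add this example: every row is padded
        # to the length of the batch's longest member.
        longest = max(lengths[i] for i in current + [idx])
        if current and (len(current) + 1) * longest > max_tokens_per_batch:
            batches.append(current)
            current = []
        current.append(idx)
    if current:
        batches.append(current)
    return batches


# Example with a 2000-token budget, as in the command above.
print(pack_batches([512, 480, 128, 96, 64], max_tokens_per_batch=2000))
# -> [[0, 1, 2], [3, 4]]
```

This is also why the 11,197 batches are unevenly sized: under a fixed token budget, batches of long examples hold few rows and batches of short examples hold many.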
......
+#!/usr/bin/env python
from typing import Union
import fire
......
+#!/usr/bin/env python
import os
from pathlib import Path
from typing import Dict, List
......
+#!/usr/bin/env python
import argparse
import gc
import os
+import sys
import warnings
from pathlib import Path
from typing import List
@@ -13,7 +16,6 @@ from torch.nn import functional as F
from finetune import SummarizationModule, TranslationModule
from finetune import main as ft_main
from initialization_utils import copy_layers, init_student
-from lightning_base import generic_train
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import (
@@ -27,6 +29,11 @@ from utils import (
)

+# need the parent dir module
+sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
+from lightning_base import generic_train  # noqa
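The inserted `sys.path` line is what lets the script find `lightning_base.py` one directory up when executed directly. A small illustration of the `parents` indexing; the layout (`examples/lightning_base.py`, with this script in `examples/seq2seq/`) is assumed from the imports:

```python
from pathlib import Path

# Assumed location: /repo/examples/seq2seq/distillation.py
p = Path("/repo/examples/seq2seq/distillation.py")
print(p.parents[0])  # /repo/examples/seq2seq  -- the script's own directory
print(p.parents[1])  # /repo/examples          -- where lightning_base.py lives
```

Inserting at index 2 rather than 0 leaves the script's own directory (and any caller-set entries) ahead of the parent directory in the module search order.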

class BartSummarizationDistiller(SummarizationModule):
    """Supports Bart, Pegasus and other models that inherit from Bart."""
......
+#!/usr/bin/env python
from pathlib import Path
import fire
......
+#!/usr/bin/env python
import argparse
import glob
import logging
import os
+import sys
import time
from collections import defaultdict
from pathlib import Path
@@ -13,7 +16,6 @@ import torch
from torch.utils.data import DataLoader
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
-from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import (
@@ -34,6 +36,11 @@ from utils import (
)

+# need the parent dir module
+sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
+from lightning_base import BaseTransformer, add_generic_args, generic_train  # noqa

logger = logging.getLogger(__name__)
......
-# Add parent directory to python path to access lightning_base.py
-export PYTHONPATH="../":"${PYTHONPATH}"
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
+# run ./finetune.sh --help to see all the possible options
python finetune.py \
......
+#!/usr/bin/env python
from pathlib import Path
import fire
......
+#!/usr/bin/env python
"""Fill examples with bitext up to max_tokens without breaking up examples.
[['I went', 'yo fui'],
['to the store', 'a la tienda']
......
+#!/usr/bin/env python
import argparse
import shutil
import time
......
+#!/usr/bin/env python
import argparse
import datetime
import json
......
+#!/usr/bin/env python
import argparse
import itertools
import operator
......
+#!/usr/bin/env python
import fire
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer
-from utils import Seq2SeqDataset, pickle_save
+try:
+    from .utils import Seq2SeqDataset, pickle_save
+except ImportError:
+    from utils import Seq2SeqDataset, pickle_save
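The added try/except is the dual-mode import pattern: the relative form succeeds when the file is imported as part of a package (as the tests below do), and the plain form succeeds when the script is executed directly, since the script's own directory is then on `sys.path`. A generic sketch, with hypothetical module and function names:

```python
# Dual-mode import sketch; `helpers` and `load_data` are hypothetical names.
try:
    # Works when this file is imported as a package submodule,
    # e.g. `from mypackage.myscript import ...`.
    from .helpers import load_data
except ImportError:
    # Works when the file is run directly (`./myscript.py`): relative
    # imports raise ImportError without a parent package, so fall back
    # to a plain import resolved against the script's own directory.
    from helpers import load_data
```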

def save_len_file(
......
+#!/usr/bin/env python
import argparse
import os
import sys
......
@@ -6,14 +6,13 @@ import numpy as np
import pytest
from torch.utils.data import DataLoader
-from pack_dataset import pack_data_dir
-from save_len_file import save_len_file
-from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
from transformers import AutoTokenizer
from transformers.modeling_bart import shift_tokens_right
from transformers.testing_utils import slow
-from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
+from .pack_dataset import pack_data_dir
+from .save_len_file import save_len_file
+from .test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
+from .utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset

BERT_BASE_CASED = "bert-base-cased"
......
@@ -14,19 +14,13 @@
# limitations under the License.

import io
-import unittest
-try:
-    from .utils import calculate_bleu
-except ImportError:
-    from utils import calculate_bleu
import json
+import unittest

from parameterized import parameterized
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
from transformers.testing_utils import get_tests_dir, require_torch, slow, torch_device
+from utils import calculate_bleu

filename = get_tests_dir() + "/test_data/fsmt/fsmt_val_data.json"
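`calculate_bleu` now comes from the sibling `utils` module. For orientation, here is a sketch of what such a helper typically looks like; this assumes it wraps sacrebleu, so check `utils.py` for the real signature:

```python
# Assumed shape of a BLEU helper built on sacrebleu (not copied from utils.py).
from sacrebleu import corpus_bleu


def calculate_bleu(output_lns, refs_lns):
    """Corpus-level BLEU between generated lines and reference lines."""
    return {"bleu": round(corpus_bleu(output_lns, [refs_lns]).score, 4)}


# Identical hypothesis and reference lines score 100.0.
print(calculate_bleu(["the cat sat on the mat"], ["the cat sat on the mat"]))
```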
......