Unverified commit eadd870b, authored by Stas Bekman, committed by GitHub

[seq2seq] make it easier to run the scripts (#7274)

parent 8d3bb781
@@ -100,7 +100,7 @@ All finetuning bash scripts call finetune.py (or distillation.py) with reasonabl
To see all the possible command line options, run:
```bash
./finetune.sh --help # this calls python finetune.py --help
-python finetune.py --help
+./finetune.py --help
```
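Note: the `./finetune.py` form relies on the `#!/usr/bin/env python` shebang this commit adds to each script, plus the executable bit on the file. If your checkout lacks the executable bit (file-mode changes are not visible in this view), a quick fix might be:
```bash
# make the scripts directly executable, then call them without the `python` prefix
chmod +x finetune.py run_eval.py run_distributed_eval.py save_len_file.py
./finetune.py --help
```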
### Finetuning Training Params
@@ -197,7 +197,7 @@ If 'translation' is in your task name, the computed metric will be BLEU. Otherwi
For t5, you need to specify --task translation_{src}_to_{tgt} as follows:
```bash
export DATA_DIR=wmt_en_ro
-python run_eval.py t5-base \
+./run_eval.py t5-base \
$DATA_DIR/val.source t5_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path enro_bleu.json \
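--task translation_en_to_ro  # illustrative completion; the hunk cuts the command off here, and --task follows the translation_{src}_to_{tgt} pattern described above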
@@ -211,7 +211,7 @@ python run_eval.py t5-base \
This command works for MBART, although the BLEU score is suspiciously low.
```bash
export DATA_DIR=wmt_en_ro
-python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
+./run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path enro_bleu.json \
--task translation \
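--fp16  # illustrative final flag; the hunk truncates the original command here (run_eval.py also accepts flags such as --bs and --n_obs)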
@@ -224,7 +224,7 @@ python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_gen
Summarization (xsum will be very similar):
```bash
export DATA_DIR=cnn_dm
-python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
+./run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path cnn_rouge.json \
--task summarization \
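--n_obs 100  # illustrative final flag (limit evaluation to the first 100 examples); the hunk truncates the original command here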
@@ -238,7 +238,7 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
### Multi-GPU Evaluation
Here is a command to run xsum evaluation on 8 GPUs. It is more than linearly faster than run_eval.py in some cases
because it uses SortishSampler to minimize padding. You can also use it on 1 GPU. `data_dir` must have
-`{type_path}.source` and `{type_path}.target`. Run `python run_distributed_eval.py --help` for all clargs.
+`{type_path}.source` and `{type_path}.target`. Run `./run_distributed_eval.py --help` for all clargs.
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
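    --model_name sshleifer/distilbart-xsum-12-3 \
    --data_dir xsum \
    --save_dir xsum_generations
    # the three flags above are an illustrative completion; the diff truncates the original
    # command here, and the flag names are assumed from run_distributed_eval.py --help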
@@ -371,11 +371,11 @@ This feature can only be used:
- with fairseq installed
- on 1 GPU
- without sortish sampler
-- after calling `python save_len_file.py $tok $data_dir`
+- after calling `./save_len_file.py $tok $data_dir`
For example,
```bash
-python save_len_file.py Helsinki-NLP/opus-mt-en-ro wmt_en_ro
+./save_len_file.py Helsinki-NLP/opus-mt-en-ro wmt_en_ro
./dynamic_bs_example.sh --max_tokens_per_batch=2000 --output_dir benchmark_dynamic_bs
```
splits `wmt_en_ro/train` into 11,197 batches of uneven length and can finish one epoch in 8 minutes on a V100.
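Before launching, it may help to verify the preconditions above; a minimal sanity check, assuming `save_len_file.py` writes `{type_path}.len` files into the data dir (the file naming is an assumption):
```bash
python -c "import fairseq" && echo "fairseq installed"  # the feature requires fairseq
ls wmt_en_ro/*.len  # save_len_file.py should have produced length files here (assumed naming)
```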
......
+#!/usr/bin/env python
from typing import Union
import fire
......
+#!/usr/bin/env python
import os
from pathlib import Path
from typing import Dict, List
......
+#!/usr/bin/env python
import argparse
import gc
import os
import sys
import warnings
from pathlib import Path
from typing import List
@@ -13,7 +16,6 @@ from torch.nn import functional as F
from finetune import SummarizationModule, TranslationModule
from finetune import main as ft_main
from initialization_utils import copy_layers, init_student
-from lightning_base import generic_train
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import (
@@ -27,6 +29,11 @@ from utils import (
)
+# need the parent dir module
+sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
+from lightning_base import generic_train  # noqa
class BartSummarizationDistiller(SummarizationModule):
"""Supports Bart, Pegasus and other models that inherit from Bart."""
......
+#!/usr/bin/env python
from pathlib import Path
import fire
......
+#!/usr/bin/env python
import argparse
import glob
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
@@ -13,7 +16,6 @@ import torch
from torch.utils.data import DataLoader
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
-from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import (
@@ -34,6 +36,11 @@ from utils import (
)
+# need the parent dir module
+sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
+from lightning_base import BaseTransformer, add_generic_args, generic_train  # noqa
logger = logging.getLogger(__name__)
......
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
# the proper usage is documented in the README; you need to specify data_dir, output_dir and model_name_or_path
# run ./finetune.sh --help to see all the possible options
python finetune.py \
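    "$@"  # probable tail of the script: forward user-supplied flags to finetune.py, which is how `./finetune.sh --help` reaches `python finetune.py --help`; any default flags in between are cut off in this view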
......
+#!/usr/bin/env python
from pathlib import Path
import fire
......
+#!/usr/bin/env python
"""Fill examples with bitext up to max_tokens without breaking up examples.
[['I went', 'yo fui'],
['to the store', 'a la tienda']
......
+#!/usr/bin/env python
import argparse
import shutil
import time
......
+#!/usr/bin/env python
import argparse
import datetime
import json
......
+#!/usr/bin/env python
import argparse
import itertools
import operator
......
+#!/usr/bin/env python
import fire
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer
-try:
-    from .utils import Seq2SeqDataset, pickle_save
-except ImportError:
-    from utils import Seq2SeqDataset, pickle_save
+from utils import Seq2SeqDataset, pickle_save
def save_len_file(
......
+#!/usr/bin/env python
import argparse
import os
import sys
......
@@ -6,14 +6,13 @@ import numpy as np
import pytest
from torch.utils.data import DataLoader
+from pack_dataset import pack_data_dir
+from save_len_file import save_len_file
+from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
from transformers import AutoTokenizer
from transformers.modeling_bart import shift_tokens_right
from transformers.testing_utils import slow
-from .pack_dataset import pack_data_dir
-from .save_len_file import save_len_file
-from .test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
-from .utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
+from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
BERT_BASE_CASED = "bert-base-cased"
......
@@ -14,19 +14,13 @@
# limitations under the License.
import io
import json
import unittest

-try:
-    from .utils import calculate_bleu
-except ImportError:
-    from utils import calculate_bleu

from parameterized import parameterized
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
from transformers.testing_utils import get_tests_dir, require_torch, slow, torch_device

+from utils import calculate_bleu
filename = get_tests_dir() + "/test_data/fsmt/fsmt_val_data.json"
......