Unverified Commit 40457bce authored by Sam Shleifer, committed by GitHub

examples/seq2seq supports translation (#5202)

parent d12ceb48
This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
Summarization support is more mature than translation support.
Please tag @sshleifer with any issues/unexpected behaviors, or send a PR!
For `bertabs` instructions, see `bertabs/README.md`.
### Data
CNN/DailyMail data
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz
export CNN_DIR=${PWD}/cnn_dm
```
This should create a directory called `cnn_dm/` with files like `test.source`.
To use your own data, copy that file format: each article to be summarized goes on its own line.
XSUM Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```
WMT16 English-Romanian Translation Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
export ENRO_DIR=${PWD}/wmt_en_ro
```
If you are using your own data, it must be formatted as one directory with 6 files: `train.source`, `train.target`, `val.source`, `val.target`, `test.source`, and `test.target`.
The `.source` files are the input, the `.target` files are the desired output.
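As a quick sanity check on a custom dataset, you can confirm that every `.source` file has exactly as many lines as its `.target` counterpart (a minimal sketch; `my_data_dir` is a placeholder for your directory):
```bash
# each line of a .source file must pair with the line at the same
# index in the matching .target file, so line counts should agree
for split in train val test; do
  wc -l my_data_dir/$split.source my_data_dir/$split.target
done
```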
### Evaluation
To create summaries for each article in the dataset, run:
```bash
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
```
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
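For example, a concrete CNN/DailyMail invocation (paths assume the data download above; `--reference_path` enables scoring and `--bs` sets the batch size) might look like:
```bash
# evaluate bart-large-cnn on the CNN/DailyMail test set
python run_eval.py \
    $CNN_DIR/test.source \
    cnn_test_generations.txt \
    facebook/bart-large-cnn \
    --reference_path $CNN_DIR/test.target \
    --score_path cnn_rouge_scores.txt \
    --bs 4
```
For translation, pass `--metric bleu` and point the paths at `$ENRO_DIR` instead.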
### Summarization Finetuning
Run/modify `finetune.sh`.
The following command should work on a 16GB GPU:
```bash
./finetune.sh \
--data_dir $XSUM_DIR \
--train_batch_size=1 \
--eval_batch_size=1 \
--output_dir=xsum_results \
--num_train_epochs 1 \
--model_name_or_path facebook/bart-large
```
Tips (most of these apply specifically to summarization finetuning):
- 1 epoch at batch size 1 for `bart-large` takes 24 hours and requires 13GB of GPU RAM with fp16 on an NVIDIA V100.
- Try `bart-base`, `--freeze_encoder`, or `--freeze_embeds` for faster training and a larger batch size (3hr/epoch with bs=8; see the XSUM Shared Task command below).
- `fp16_opt_level=O1` (the default) works best.
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries (it rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`.
- At the moment, `--do_predict` does not work in a multi-GPU setting. Use `evaluate_checkpoint` or `run_eval.py` instead.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground-truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()` (see the snippet after this list).
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb_shared` or `--logger wandb`. It is useful for reproducibility.
- This warning can be safely ignored:
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
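As mentioned in the CNN/DailyMail tip above, here is a minimal re-scoring sketch using `calculate_rouge` from `utils.py` (run from `examples/seq2seq`; paths are placeholders for your own):
```python
from utils import calculate_rouge  # examples/seq2seq/utils.py

output_dir = "cnndm_results"  # placeholder: your --output_dir
preds = [x.rstrip() for x in open(f"{output_dir}/test_generations.txt").readlines()]
refs = [x.rstrip() for x in open("cnn_dm/test.target").readlines()]  # untruncated labels
print(calculate_rouge(preds, refs))  # dict of rouge1/rouge2/rougeL scores
```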
#### Finetuning Outputs
As you train, `output_dir` will fill up with files that look roughly like the listing below (comments are mine).
Some are metrics, some are checkpoints, and some are metadata. Here is a quick tour:
```bash
output_dir
├── best_tfmr # this is a huggingface checkpoint generated by save_pretrained. It is the same model as the PL .ckpt file below
│   ├── config.json
│   ├── merges.txt
│   ├── pytorch_model.bin
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── git_log.json # repo, branch, and commit hash
├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score.
├── metrics.json # new validation metrics will continually be appended to this
├── student # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
│   ├── config.json
│   └── pytorch_model.bin
├── test_generations.txt
# ^^ are the summaries or translations produced by your best checkpoint on the test data. Populated when training is done
├── test_results.txt # a convenience file with the test set metrics. This data is also in metrics.json['test']
├── hparams.pkl # the command line args passed after some light preprocessing. Should be saved fairly quickly.
```
After training, you can recover the best checkpoint by running
```python
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
```
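From there the standard generation API applies; a short sketch (the input text and generation parameters are illustrative, and the tokenizer files are saved alongside the model in `best_tfmr`):
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

output_dir = "xsum_results"  # placeholder: your --output_dir
model = AutoModelForSeq2SeqLM.from_pretrained(f"{output_dir}/best_tfmr")
tokenizer = AutoTokenizer.from_pretrained(f"{output_dir}/best_tfmr")

batch = tokenizer.batch_encode_plus(["A long news article to summarize ..."], return_tensors="pt")
summary_ids = model.generate(**batch, num_beams=4, max_length=60)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```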
### XSUM Shared Task
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
```bash
./finetune.sh \
--data_dir $XSUM_DIR \
--output_dir xsum_frozen_embs \
--model_name_or_path facebook/bart-large \
--logger wandb_shared \
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
--num_train_epochs 6 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100
```
Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
### Distilbart
#### No Teacher Distillation
To run the simpler distilbart-cnn-style distillation, all you need is data, a GPU, and a properly initialized student; you don't even need `distillation.py`.
Some [un-finetuned students](https://huggingface.co/models?search=sshleifer%2Fstudent) are available for replication purposes.
They are initialized by copying layers from the associated `bart-large-{cnn|xsum}` teacher using `--init_strategy alternate` (see `initialization_utils.py`).
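Roughly, the `alternate` strategy copies an evenly spread subset of teacher layers into the student. A simplified sketch of the idea, not the exact `initialization_utils.py` code:
```python
def copy_spread_of_layers(teacher_layers, student_layers):
    """Initialize each student layer from an evenly spread teacher layer.

    For a 12-layer teacher and a 6-layer student this copies teacher
    layers [0, 2, 4, 7, 9, 11], matching get_layers_to_copy in distillation.py.
    """
    spread = {1: [0], 3: [0, 6, 11], 6: [0, 2, 4, 7, 9, 11], 12: list(range(12))}
    for student_layer, teacher_idx in zip(student_layers, spread[len(student_layers)]):
        student_layer.load_state_dict(teacher_layers[teacher_idx].state_dict())
```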
The command that produced `sshleifer/distilbart-cnn-12-6` is
```bash
./train_distilbart_cnn.sh
```
Runtime: 6 hours on a 24GB NVIDIA RTX GPU.
*Note*: You can get the same simple distillation logic with `./run_distiller.sh --no_teacher` followed by the same arguments as `train_distilbart_cnn.sh`.
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
because you will have the same hyperparameters logged in every run.
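For instance, a hypothetical no-teacher run for a 12-encoder/6-decoder student might look like this (the flags below exist in `distillation.py`'s argparse, but the authoritative argument list lives in `train_distilbart_cnn.sh`):
```bash
./run_distiller.sh --no_teacher \
  --data_dir $CNN_DIR \
  --student_encoder_layers 12 --student_decoder_layers 6 \
  --output_dir distilbart_cnn_12_6
```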
#### With a teacher
*Note*: only BART variants are supported.
In this method, we try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states, using `BartSummarizationDistiller`.
This is how `sshleifer/distilbart-xsum*` checkpoints were produced.
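Conceptually, the distillation objective blends the usual label loss with terms pulling the student's logits and hidden states toward the teacher's. A simplified sketch of that objective (the weights correspond loosely to the `alpha_*` hyperparameters; the real implementation lives in `distillation.py`):
```python
import torch.nn.functional as F

def distill_loss(s_logits, t_logits, s_hidden, t_hidden, labels,
                 alpha_ce=0.8, alpha_mlm=0.2, alpha_hid=0.0, temperature=2.0):
    # soft cross-entropy between student and teacher output distributions
    ce_loss = F.kl_div(
        F.log_softmax(s_logits / temperature, dim=-1),
        F.softmax(t_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2
    # ordinary cross-entropy against the gold labels
    mlm_loss = F.cross_entropy(s_logits.view(-1, s_logits.size(-1)), labels.view(-1))
    # pull matched student hidden states toward the corresponding teacher layers
    hid_loss = F.mse_loss(s_hidden, t_hidden)
    return alpha_ce * ce_loss + alpha_mlm * mlm_loss + alpha_hid * hid_loss
```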
The command that produced `sshleifer/distilbart-xsum-12-6` is:
```bash
./train_distilbart_xsum.sh
```
Runtime: 13 hours on a 16GB V100 GPU.
### Contributing
- Follow the standard contributing guidelines and code of conduct.
- Add tests to `test_seq2seq_examples.py`.
- To run only the seq2seq tests, run the following from the root of the repository:
```bash
pytest examples/seq2seq/
```
@@ -12,7 +12,7 @@ The model is loaded with the pre-trained weights for the abstractive summarization
 git clone https://github.com/huggingface/transformers && cd transformers
 pip install .
 pip install nltk py-rouge
-cd examples/summarization
+cd examples/seq2seq/bertabs
 ```
 ## Reproduce the authors' ROUGE score
...
@@ -32,9 +32,12 @@ class Seq2SeqLoggingCallback(pl.Callback):
             results_file = od / "test_results.txt"
             generations_file = od / "test_generations.txt"
         else:
-            results_file = od / f"{type_path}_results_{trainer.global_step:05d}.txt"
-            generations_file = od / f"{type_path}_generations_{trainer.global_step:05d}.txt"
+            # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
+            # If people want this it will be easy enough to add back.
+            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
+            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
+            results_file.parent.mkdir(exist_ok=True)
+            generations_file.parent.mkdir(exist_ok=True)
         with open(results_file, "a+") as writer:
             for key in sorted(metrics):
                 if key in ["log", "progress_bar", "preds"]:
@@ -63,20 +66,25 @@ class Seq2SeqLoggingCallback(pl.Callback):
         # mp stands for million parameters
         trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

+    @rank_zero_only
+    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+        return self._write_logs(trainer, pl_module, "val")
+
     @rank_zero_only
     def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         return self._write_logs(trainer, pl_module, "test")


-def get_rouge2_checkpoint_callback(output_dir):
+def get_checkpoint_callback(output_dir, metric):
     """Saves the best model by validation ROUGE2 score."""
+    if metric == "rouge2":
+        exp = "{val_avg_rouge2:.4f}-{step_count}"
+    elif metric == "bleu":
+        exp = "{val_avg_bleu:.4f}-{step_count}"
+    else:
+        raise NotImplementedError(
+            f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
+        )
     checkpoint_callback = ModelCheckpoint(
-        filepath=os.path.join(output_dir, "{val_avg_rouge2:.4f}-{step_count}"),
-        monitor="val_rouge",
+        filepath=os.path.join(output_dir, exp),
+        monitor=f"val_{metric}",
         mode="max",
         save_top_k=1,
         period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
...
@@ -39,13 +39,12 @@ except ImportError:
     )


-class SummarizationDistiller(SummarizationModule):
+class BartSummarizationDistiller(SummarizationModule):
     loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]

     def __init__(self, hparams):
         assert Path(hparams.data_dir).exists()
-        d_layers_to_copy, student, student_cfg, teacher = self.pre_init(hparams)
+        student, student_cfg, teacher = self.pre_init(hparams)

         super().__init__(hparams, model=student, config=student_cfg)
         self.teacher = teacher
@@ -73,12 +72,15 @@ class SummarizationDistiller(SummarizationModule):
         del self.teacher.model.encoder

     def pre_init(self, hparams):
-        # Dump empty student model at a path, then call from_pretrained on it
+        self.output_dir = Path(hparams.output_dir)
+        self.output_dir.mkdir(exist_ok=True)
         teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
         student_updates = {
             "decoder_layers": hparams.student_decoder_layers,
             "encoder_layers": hparams.student_encoder_layers,
         }
+        if hparams.length_penalty != -1:
+            student_updates["length_penalty"] = hparams.length_penalty
         d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
         e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
         hparams.d_layer_to_copy = d_layers_to_copy
@@ -89,9 +91,13 @@ class SummarizationDistiller(SummarizationModule):
         student_cfg = BartConfig(**kw)
         student = BartForConditionalGeneration(student_cfg)
         student, _ = init_student(student, teacher)
+        save_dir = self.output_dir.joinpath("student")
+        save_dir.mkdir(exist_ok=True)
         self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
-        Path(hparams.output_dir).mkdir(exist_ok=True)
-        return d_layers_to_copy, student, student_cfg, teacher
+        student.save_pretrained(save_dir)
+        hparams.model_name_or_path = str(save_dir)
+        return student, student_cfg, teacher

     def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
         if teacher.config.model_type == "t5":
@@ -154,7 +160,6 @@ class SummarizationDistiller(SummarizationModule):
     def configure_optimizers(self):
         "Prepare optimizer and schedule (linear warmup and decay)"
         model = self.model
-
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
@@ -180,18 +185,11 @@ class SummarizationDistiller(SummarizationModule):
         # parser.add_argument("--alpha_cos", default=0.0, type=float)
         parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
         parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
-        parser.add_argument(
-            "--student_decoder_layers", default=12, type=int, required=False,
-        )
-        parser.add_argument(
-            "--student_encoder_layers", default=12, type=int, required=False,
-        )
-        parser.add_argument(
-            "--no_teacher", action="store_true", default=False,
-        )
-        parser.add_argument(  # TODO: remove
-            "--enc_only", action="store_true", default=False,
-        )
+        parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
+        parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
+        parser.add_argument("--no_teacher", action="store_true", default=False)
+        parser.add_argument("--length_penalty", type=float, default=-1)
         return parser

     def _step(self, batch):
@@ -269,12 +267,14 @@ class SummarizationDistiller(SummarizationModule):
         return sum(hidden_losses)


-class T5SummarizationDistiller(SummarizationDistiller):
+class T5SummarizationDistiller(BartSummarizationDistiller):
     def pre_init(self, hparams):
         raise NotImplementedError("T5 Distillation does not work yet")
+        self.output_dir = Path(hparams.output_dir)
+        self.output_dir.mkdir(exist_ok=True)
         teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
         n_layer = hparams.student_decoder_layers
-        assert n_layer == hparams.student_encoder_layers  # TODO(SS): relax this
+        assert n_layer == hparams.student_encoder_layers  # TODO(SS): relax this constraint so that we can do 12-6.
         d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block))
         e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block))
         student_updates = {"num_layers": n_layer}
@@ -291,8 +291,13 @@ class T5SummarizationDistiller(SummarizationDistiller):
         Path(hparams.output_dir).mkdir(exist_ok=True)
         task_specific_params = student.config.task_specific_params
         if task_specific_params is not None:
-            student.config.update(task_specific_params.get("summarization", {}))
-        return d_layers_to_copy, student, student_cfg, teacher
+            student.config.update(task_specific_params.get("summarization", {}))  # TODO: dont hardcode
+        save_dir = self.output_dir.joinpath("student")
+        save_dir.mkdir(exist_ok=True)
+        student.save_pretrained(save_dir)
+        hparams.model_name_or_path = str(save_dir)
+        return student, student_cfg, teacher

     def freeze_embeds(self):
         freeze_params(self.model.shared)
@@ -386,7 +391,7 @@ def create_module(args):
     elif args.enc_only:
         raise ValueError("Deleted that")
     else:
-        module_cls = SummarizationDistiller
+        module_cls = BartSummarizationDistiller
     args.setup_cls: str = module_cls.__name__
     model = module_cls(args)
     return model
@@ -418,18 +423,18 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
 def get_layers_to_copy(n_to_get, tot):
     all_layers = list(range(tot))
     if tot == 12:  # Alternating for special cases
-        layers_to_copy = {  # maps # layers in student -> which teacher layers to copy
-            6: [0, 2, 4, 7, 9, 11],
-            1: [11],
+        layers_to_copy = {  # maps num layers in student -> which teacher layers to copy
+            1: [0],
+            2: [0, 6],
             3: [0, 6, 11],
-            2: [0, 11],
             4: [0, 4, 8, 11],
+            6: [0, 2, 4, 7, 9, 11],
             9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
             12: all_layers,
         }
         return layers_to_copy[n_to_get]
     else:
-        return all_layers[:n_to_get]
+        return all_layers[:n_to_get]  # TODO: better version on theseus-bart branch


 def distill_main(args):
@@ -443,7 +448,7 @@ def distill_main(args):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
+    parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
     args = parser.parse_args()
     distill_main(args)
...
@@ -3,6 +3,7 @@ import glob
 import logging
 import os
 import time
+from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -23,12 +24,14 @@ try:
         flatten_list,
         pickle_save,
         save_git_info,
+        save_json,
         freeze_params,
         calculate_rouge,
         get_git_info,
         ROUGE_KEYS,
+        calculate_bleu_score,
     )
-    from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
+    from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
 except ImportError:
     from utils import (
         use_task_specific_params,
@@ -37,12 +40,14 @@ except ImportError:
         flatten_list,
         pickle_save,
         save_git_info,
+        save_json,
         freeze_params,
         calculate_rouge,
         get_git_info,
         ROUGE_KEYS,
+        calculate_bleu_score,
     )
-    from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
+    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback

 logger = logging.getLogger(__name__)
@@ -50,15 +55,18 @@ logger = logging.getLogger(__name__)
 class SummarizationModule(BaseTransformer):
     mode = "summarization"
     loss_names = ["loss"]
+    metric_names = ROUGE_KEYS
+    val_metric = "rouge2"

     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
         use_task_specific_params(self.model, "summarization")
         save_git_info(self.hparams.output_dir)
-        self.metrics_save_path = Path(self.output_dir) / "metrics.pkl"
+        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
         self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
+        pickle_save(self.hparams, self.hparams_save_path)
         self.step_count = 0
-        self.metrics = {"train": [], "val": [], "test": []}
+        self.metrics = defaultdict(list)

         self.dataset_kwargs: dict = dict(
             data_dir=self.hparams.data_dir,
@@ -89,12 +97,12 @@ class SummarizationModule(BaseTransformer):
     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
-        if self.model.config.model_type == "bart":
+        try:
             freeze_params(self.model.model.shared)
             for d in [self.model.model.encoder, self.model.model.decoder]:
                 freeze_params(d.embed_positions)
                 freeze_params(d.embed_tokens)
-        else:
+        except AttributeError:
             freeze_params(self.model.shared)
             for d in [self.model.encoder, self.model.decoder]:
                 freeze_params(d.embed_tokens)
@@ -130,19 +138,22 @@
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
-        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]}
-        rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss)
+        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]}
+        rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
         rouges.update({k: v.item() for k, v in losses.items()})
         losses.update(rouges)
         metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
         metrics["step_count"] = self.step_count
         self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
         preds = flatten_list([x["preds"] for x in outputs])
-        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor}
+        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}

-    def save_metrics(self, metrics, prefix) -> None:
-        self.metrics[prefix].append(metrics)
-        pickle_save(self.metrics, self.metrics_save_path)
+    def save_metrics(self, latest_metrics, type_path) -> None:
+        self.metrics[type_path].append(latest_metrics)
+        save_json(self.metrics, self.metrics_save_path)
+
+    def calc_generative_metrics(self, preds, target) -> Dict:
+        return calculate_rouge(preds, target)

     def _generative_step(self, batch: dict) -> dict:
         pad_token_id = self.tokenizer.pad_token_id
@@ -154,7 +165,7 @@
         target = self.ids_to_clean_text(y)
         loss_tensors = self._step(batch)
         base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
-        rouge: Dict = calculate_rouge(preds, target)
+        rouge: Dict = self.calc_generative_metrics(preds, target)
         summ_len = np.mean(lmap(len, generated_ids))
         base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
         return base_metrics
@@ -259,15 +270,33 @@
         parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
         parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
         parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
+        parser.add_argument(
+            "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
+        )
         return parser


+class TranslationModule(SummarizationModule):
+    mode = "translation"
+    loss_names = ["loss"]
+    metric_names = ["bleu"]
+    val_metric = "bleu"
+
+    def calc_generative_metrics(self, preds, target) -> dict:
+        return calculate_bleu_score(preds, target)
+
+
 def main(args, model=None) -> SummarizationModule:
     Path(args.output_dir).mkdir(exist_ok=True)
     if len(os.listdir(args.output_dir)) > 3 and args.do_train:
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
     if model is None:
-        model: BaseTransformer = SummarizationModule(args)
+        if args.task == "summarization":
+            model: SummarizationModule = SummarizationModule(args)
+        else:
+            model: SummarizationModule = TranslationModule(args)
+    dataset = Path(args.data_dir).name
     if (
         args.logger == "default"
         or args.fast_dev_run
@@ -278,17 +307,17 @@ def main(args, model=None) -> SummarizationModule:
     elif args.logger == "wandb":
         from pytorch_lightning.loggers import WandbLogger

-        logger = WandbLogger(name=model.output_dir.name)
+        logger = WandbLogger(name=model.output_dir.name, project=dataset)
     elif args.logger == "wandb_shared":
         from pytorch_lightning.loggers import WandbLogger

-        # TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
-        logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
+        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
     trainer: pl.Trainer = generic_train(
         model,
         args,
         logging_callback=Seq2SeqLoggingCallback(),
-        checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
+        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
         logger=logger,
         # TODO: early stopping callback seems messed up
     )
...
 # Add parent directory to python path to access lightning_base.py
 export PYTHONPATH="../":"${PYTHONPATH}"

-# --model_name_or_path=t5-base for t5
-# the proper usage is documented in the README
+# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
 python finetune.py \
-    --model_name_or_path=facebook/bart-large \
     --learning_rate=3e-5 \
     --fp16 \
     --gpus 1 \
@@ -16,5 +11,4 @@ python finetune.py \
     --n_val 1000 \
     --val_check_interval 0.1 \
     --sortish_sampler \
-    --max_target_length=56 \
     $@
-#CNN_DIR = /home/shleifer/transformers_fork/examples/summarization/bart/cnn_dm
 # Add parent directory to python path to access lightning_base.py
 export PYTHONPATH="../":"${PYTHONPATH}"
...
@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 try:
-    from .finetune import calculate_rouge, use_task_specific_params
+    from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
 except ImportError:
-    from finetune import calculate_rouge, use_task_specific_params
+    from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score

 DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -22,8 +22,14 @@ def chunks(lst, n):
         yield lst[i : i + n]


-def generate_summaries(
-    examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
+def generate_summaries_or_translations(
+    examples: list,
+    out_file: str,
+    model_name: str,
+    batch_size: int = 8,
+    device: str = DEFAULT_DEVICE,
+    fp16=False,
+    **gen_kwargs,
 ) -> None:
     fout = Path(out_file).open("w", encoding="utf-8")
     model_name = str(model_name)
@@ -39,11 +45,10 @@ def generate_summaries(
     for batch in tqdm(list(chunks(examples, batch_size))):
         if "t5" in model_name:
             batch = [model.config.prefix + text for text in batch]
-        dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to(
-            device
-        )
-        summaries = model.generate(**dct)
+        batch = tokenizer.batch_encode_plus(
+            batch, max_length=1024, return_tensors="pt", truncation=True, pad_to_max_length=True
+        ).to(device)
+        summaries = model.generate(**batch, **gen_kwargs)
         dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
         for hypothesis in dec:
             fout.write(hypothesis + "\n")
@@ -57,22 +62,26 @@ def run_generate():
     parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
     parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
     parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
+    parser.add_argument("--metric", type=str, choices=["bleu", "rouge"], default="rouge")
     parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
     parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
     parser.add_argument("--fp16", action="store_true")
     args = parser.parse_args()
     examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]

-    generate_summaries(
+    generate_summaries_or_translations(
         examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16
     )
-    if args.score_path is not None:
-        output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
-        reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
-        rouge: dict = calculate_rouge(output_lns, reference_lns)
-        json.dump(rouge, open("score_path", "w+"))
+    output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
+    scores = {}
+    if args.reference_path is not None:
+        score_fn = {"bleu": calculate_bleu_score, "rouge": calculate_rouge}[args.metric]
+        reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
+        scores: dict = score_fn(output_lns, reference_lns)
+        if args.score_path is not None:
+            json.dump(scores, open("score_path", "w+"))
+    return scores


 if __name__ == "__main__":
...