Unverified Commit 09a2f406 authored by Sam Shleifer, committed by GitHub

Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792)


Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
parent 4b506a37
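The gist of the change: instead of tokenizing every line of the `.source`/`.target` files up front and holding the resulting tensors in memory, `Seq2SeqDataset` now stores only the file paths (plus per-line character lengths) and fetches and tokenizes one line at a time in `__getitem__` via `linecache`. A minimal sketch of the idea, not the repository code (class and variable names are illustrative):

```python
import linecache
from pathlib import Path


class LazyLineDataset:
    """Minimal sketch: keep file paths, not tokenized tensors, in memory."""

    def __init__(self, src_file: str):
        self.src_file = Path(src_file)
        # Only one small int per line stays resident.
        self.src_lens = [len(line) for line in self.src_file.open().readlines()]

    def __len__(self):
        return len(self.src_lens)

    def __getitem__(self, index: int) -> str:
        # linecache is 1-indexed and caches the file contents internally.
        return linecache.getline(str(self.src_file), index + 1).rstrip("\n")
```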
......@@ -7,27 +7,24 @@ For `bertabs` instructions, see `bertabs/README.md`.
### Data
CNN/DailyMail data
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz
export CNN_DIR=${PWD}/cnn_dm
```
This should make a directory called `cnn_dm/` with files like `test.source`.
To use your own data, copy that file format: each article to be summarized is on its own line.

XSUM Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```
WMT16 English-Romanian Translation Data:
```bash
......@@ -40,7 +37,7 @@ export ENRO_DIR=${PWD}/wmt_en_ro
If you are using your own data, it must be formatted as one directory with 6 files: `train.source`, `train.target`, `val.source`, `val.target`, `test.source`, `test.target`.
The `.source` files are the input, the `.target` files are the desired output.
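If you just want to smoke-test the pipeline, a toy directory in this layout can be generated with a few lines of Python. This sketch is illustrative only (the `toy_data` name and contents are made up, loosely mirroring the test fixtures later in this diff):

```python
from pathlib import Path

# Illustrative toy data in the expected 6-file layout.
articles = [" Sam ate lunch today.", "Sams lunch ingredients."]
summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]

data_dir = Path("toy_data")
data_dir.mkdir(exist_ok=True)
for split in ("train", "val", "test"):
    (data_dir / f"{split}.source").write_text("\n".join(articles) + "\n")
    (data_dir / f"{split}.target").write_text("\n".join(summaries) + "\n")
```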
### Tips and Tricks
General Tips:
......@@ -64,6 +61,10 @@ Summarization Tips:
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)
**Update 2020-07-18**
Datasets: `Seq2SeqDataset` will be used for all models besides MBart, for which `MBartDataset` will be used.
A new dataset class is needed to support multilingual tasks.
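Concretely, `finetune.py` (see its diff below) defaults `self.dataset_class` to `Seq2SeqDataset` and swaps in `MBartDataset` when the tokenizer is an `MBartTokenizer`. A simplified sketch of that selection logic (the helper name is made up):

```python
from transformers import MBartTokenizer

from utils import MBartDataset, Seq2SeqDataset  # examples/seq2seq/utils.py; run from that directory


def pick_dataset_class(tokenizer):
    # Simplified sketch of the logic added to finetune.py.
    if isinstance(tokenizer, MBartTokenizer):
        return MBartDataset
    return Seq2SeqDataset
```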
### Summarization Finetuning
Run/modify `finetune.sh`
......@@ -78,8 +79,6 @@ The following command should work on a 16GB GPU:
--model_name_or_path facebook/bart-large
```
### Translation Finetuning
First, follow the wmt_en_ro download instructions.
......@@ -124,23 +123,6 @@ from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
```
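From there you can run generation directly. A hedged sketch (the output directory and input text are made up, and generation uses the model's default settings):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

output_dir = "my_finetune_output"  # illustrative: wherever finetune.sh wrote best_tfmr
tokenizer = AutoTokenizer.from_pretrained(f"{output_dir}/best_tfmr")
model = AutoModelForSeq2SeqLM.from_pretrained(f"{output_dir}/best_tfmr")

batch = tokenizer(["An example article to summarize."], return_tensors="pt", truncation=True)
generated_ids = model.generate(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
```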
#### XSUM Shared Task
Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration.
Here is an example command, but you can adjust it however you want. Hopefully this will make debugging and collaboration easier!
```bash
WANDB_PROJECT='hf_xsum' ./finetune.sh \
--data_dir $XSUM_DIR \
--output_dir xsum_frozen_embs \
--model_name_or_path facebook/bart-large \
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
--num_train_epochs 6 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--logger_name wandb
```
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
### Evaluation Commands
To create summaries for each article in the dataset, we use `run_eval.py`. Here are a few commands that run eval for different tasks and models.
......
......@@ -15,28 +15,15 @@ from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Conf
try:
from .finetune import SummarizationModule
from .initialization_utils import init_student, copy_layers
from .utils import (
use_task_specific_params,
SummarizationDataset,
pickle_load,
freeze_params,
assert_all_frozen,
any_requires_grad,
)
from .finetune import main as ft_main
from .initialization_utils import init_student, copy_layers
from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
except ImportError:
from finetune import SummarizationModule
from finetune import main as ft_main
from initialization_utils import init_student, copy_layers
from utils import (
use_task_specific_params,
SummarizationDataset,
pickle_load,
freeze_params,
assert_all_frozen,
any_requires_grad,
)
from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
class BartSummarizationDistiller(SummarizationModule):
......@@ -115,11 +102,6 @@ class BartSummarizationDistiller(SummarizationModule):
if self.different_encoder:
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
def get_dataset(self, type_path) -> SummarizationDataset:
n_obs = self.n_obs[type_path]
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, **self.dataset_kwargs)
return dataset
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
if mask is not None:
# mask has False at padding_idx
......
......@@ -21,7 +21,6 @@ try:
from .utils import (
assert_all_frozen,
use_task_specific_params,
SummarizationDataset,
lmap,
flatten_list,
pickle_save,
......@@ -32,12 +31,17 @@ try:
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
Seq2SeqDataset,
MBartDataset,
)
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
except ImportError:
from utils import (
Seq2SeqDataset,
MBartDataset,
assert_all_frozen,
use_task_specific_params,
SummarizationDataset,
lmap,
flatten_list,
pickle_save,
......@@ -48,7 +52,6 @@ except ImportError:
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
assert_all_frozen,
)
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
......@@ -100,6 +103,7 @@ class SummarizationModule(BaseTransformer):
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers
self.decoder_start_token_id = None
self.dataset_class = Seq2SeqDataset
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
......@@ -163,7 +167,7 @@ class SummarizationModule(BaseTransformer):
def _generative_step(self, batch: dict) -> dict:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
t0 = time.time()
generated_ids = self.model.generate(
input_ids=source_ids,
......@@ -187,10 +191,10 @@ class SummarizationModule(BaseTransformer):
def test_epoch_end(self, outputs):
return self.validation_epoch_end(outputs, prefix="test")
def get_dataset(self, type_path) -> SummarizationDataset:
def get_dataset(self, type_path) -> Seq2SeqDataset:
n_obs = self.n_obs[type_path]
max_target_length = self.target_lens[type_path]
dataset = SummarizationDataset(
dataset = self.dataset_class(
self.tokenizer,
type_path=type_path,
n_obs=n_obs,
......@@ -303,6 +307,8 @@ class TranslationModule(SummarizationModule):
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
if isinstance(self.tokenizer, MBartTokenizer):
self.dataset_class = MBartDataset
def calc_generative_metrics(self, preds, target) -> dict:
return calculate_bleu_score(preds, target)
......
......@@ -9,16 +9,17 @@ from unittest.mock import patch
import pytest
import torch
from pytest import param
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import AutoTokenizer, MBartTokenizer
from transformers.testing_utils import require_multigpu
from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import SummarizationDataset, lmap, load_json
from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json
logging.basicConfig(level=logging.DEBUG)
......@@ -26,6 +27,7 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"label_smoothing_eps": 0.2,
"logger_name": "default",
"length_penalty": 0.5,
"cache_dir": "",
......@@ -80,11 +82,11 @@ CHEAP_ARGS = {
def _dump_articles(path: Path, articles: list):
with path.open("w") as f:
f.write("\n".join(articles))
content = "\n".join(articles)
Path(path).open("w").writelines(content)
ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"]
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
......@@ -208,7 +210,7 @@ def test_run_eval_bart(model):
@pytest.mark.parametrize(
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
)
def test_finetune(model):
args_d: dict = CHEAP_ARGS.copy()
......@@ -260,22 +262,50 @@ def test_pack_dataset():
assert orig_paths == new_paths
@pytest.mark.parametrize(
["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
)
def test_dataset(tok):
tokenizer = AutoTokenizer.from_pretrained(tok)
def test_mbart_dataset_truncation():
tokenizer = MBartTokenizer.from_pretrained(MBART_TINY)
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc_target = 4
train_dataset = SummarizationDataset(
trunc = 4
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
train_dataset = MBartDataset(
tokenizer,
data_dir=tmp_dir,
type_path="train",
max_source_length=20,
max_target_length=trunc_target,
tgt_lang="ro_RO",
max_source_length=trunc,
max_target_length=1000, # ignored
src_lang=src_lang,
tgt_lang=tgt_lang,
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader:
assert isinstance(batch, dict)
assert batch["attention_mask"].shape == batch["input_ids"].shape
# show that articles were trimmed.
assert batch["input_ids"].shape[1] == trunc
# show that targets are the same len
assert batch["decoder_input_ids"].shape[1] == trunc
# check language codes in correct place
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
assert max_len_target > trunc # Truncated
assert max_len_source > trunc
break # No need to test every batch
@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
def test_summarization_dataset_truncation(tok):
tokenizer = AutoTokenizer.from_pretrained(tok)
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc_target = 4
train_dataset = Seq2SeqDataset(
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader:
......@@ -286,3 +316,4 @@ def test_dataset(tok):
# show that targets were truncated
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
assert max_len_target > trunc_target # Truncated
break # No need to test every batch
import itertools
import json
import linecache
import os
import pickle
import warnings
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List
......@@ -13,50 +15,20 @@ from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from transformers import BartTokenizer
def encode_file(
tokenizer,
data_path,
max_length,
pad_to_max_length=True,
return_tensors="pt",
overwrite_cache=False,
prefix="",
tok_name="",
):
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
if not overwrite_cache and cache_path.exists():
try:
examples = torch.load(cache_path)
assert isinstance(examples, list)
return examples
except Exception:
print(f"failed to load from {cache_path}, retokenizing {data_path}")
data_path = Path(data_path)
lns = lmap(str.strip, data_path.open().readlines())
lns = [prefix + text for text in lns]
assert lns, f"found empty file at {data_path}"
examples = []
for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
tokenized = tokenizer(
[text],
max_length=max_length,
padding="max_length" if pad_to_max_length else None,
truncation=True,
return_tensors=return_tensors,
**extra_kw,
)
assert tokenized.input_ids.shape[1] == max_length
examples.append(tokenized)
torch.save(lmap(dict, examples), cache_path.open("wb"))
return examples
return tokenizer(
[line],
max_length=max_length,
padding="max_length" if pad_to_max_length else None,
truncation=True,
return_tensors=return_tensors,
**extra_kw,
)
def lmap(f: Callable, x: Iterable) -> List:
......@@ -80,73 +52,111 @@ def trim_batch(
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
class SummarizationDataset(Dataset):
class Seq2SeqDataset(Dataset):
def __init__(
self,
tokenizer,
data_dir,
max_source_length,
max_target_length,
type_path="train",
max_source_length=1024,
max_target_length=56,
n_obs=None,
overwrite_cache=False,
prefix="",
src_lang=None,
tgt_lang=None,
prefix="",
):
super().__init__()
# FIXME: the rstrip logic strips all the chars, it seems.
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
if hasattr(tokenizer, "set_lang") and src_lang is not None:
tokenizer.set_lang(src_lang) # HACK: only applies to mbart
self.source = encode_file(
tokenizer,
os.path.join(data_dir, type_path + ".source"),
max_source_length,
overwrite_cache=overwrite_cache,
prefix=prefix,
tok_name=tok_name,
)
tgt_path = os.path.join(data_dir, type_path + ".target")
if hasattr(tokenizer, "set_lang"):
assert tgt_lang is not None, "--tgt_lang must be passed to build a translation"
tokenizer.set_lang(tgt_lang) # HACK: only applies to mbart
self.target = encode_file(
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
)
self.src_file = Path(data_dir).joinpath(type_path + ".source")
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
self.src_lens = self.get_char_lens(self.src_file)
self.max_source_length = max_source_length
self.max_target_length = max_target_length
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
self.tokenizer = tokenizer
self.prefix = prefix
if n_obs is not None:
self.source = self.source[:n_obs]
self.target = self.target[:n_obs]
self.pad_token_id = tokenizer.pad_token_id
self.src_lens = self.src_lens[:n_obs]
self.pad_token_id = self.tokenizer.pad_token_id
self.src_lang = src_lang
self.tgt_lang = tgt_lang
def __len__(self):
return len(self.source)
return len(self.src_lens)
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)
source_ids = source_inputs["input_ids"].squeeze()
target_ids = target_inputs["input_ids"].squeeze()
src_mask = source_inputs["attention_mask"].squeeze()
return {
"input_ids": source_ids,
"attention_mask": src_mask,
"decoder_input_ids": target_ids,
}
def __getitem__(self, index):
source_ids = self.source[index]["input_ids"].squeeze()
target_ids = self.target[index]["input_ids"].squeeze()
src_mask = self.source[index]["attention_mask"].squeeze()
return {"input_ids": source_ids, "attention_mask": src_mask, "decoder_input_ids": target_ids}
@staticmethod
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
@staticmethod
def trim_seq2seq_batch(batch, pad_token_id):
def trim_seq2seq_batch(batch, pad_token_id) -> tuple:
y = trim_batch(batch["decoder_input_ids"], pad_token_id)
source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
return source_ids, source_mask, y
def collate_fn(self, batch) -> dict:
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
pad_token_id = self.pad_token_id
y = trim_batch(target_ids, pad_token_id)
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
batch = {
"input_ids": source_ids,
"attention_mask": source_mask,
"decoder_input_ids": y,
}
return batch
def make_sortish_sampler(self, batch_size):
lens = [x["input_ids"].ne(self.pad_token_id).sum() for x in self.source]
return SortishSampler(lens, batch_size)
return SortishSampler(self.src_lens, batch_size)
class MBartDataset(Seq2SeqDataset):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.max_source_length != self.max_target_length:
warnings.warn(
f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides."
)
def __getitem__(self, index) -> Dict[str, str]:
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
return {
"tgt_texts": source_line,
"src_texts": tgt_line,
}
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
batch_encoding = self.tokenizer.prepare_translation_batch(
[x["src_texts"] for x in batch],
src_lang=self.src_lang,
tgt_texts=[x["tgt_texts"] for x in batch],
tgt_lang=self.tgt_lang,
max_length=self.max_source_length,
)
return batch_encoding.data
class SortishSampler(Sampler):
......
......@@ -118,12 +118,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
self._additional_special_tokens = list(self.lang_code_to_id.keys())
self.reset_special_tokens()
def reset_special_tokens(self) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
......@@ -183,12 +178,6 @@ class MBartTokenizer(XLMRobertaTokenizer):
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
def set_lang(self, lang: str) -> None:
"""Set the current language code in order to call tokenizer properly."""
self.cur_lang_code = self.lang_code_to_id[lang]
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]
def prepare_translation_batch(
self,
src_texts: List[str],
......@@ -215,7 +204,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
"""
if max_length is None:
max_length = self.max_len
self.cur_lang_code = self.lang_code_to_id[src_lang]
self.set_src_lang_special_tokens(src_lang)
model_inputs: BatchEncoding = self(
src_texts,
add_special_tokens=True,
......@@ -227,7 +216,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
)
if tgt_texts is None:
return model_inputs
self.set_lang(tgt_lang)
self.set_tgt_lang_special_tokens(tgt_lang)
decoder_inputs: BatchEncoding = self(
tgt_texts,
add_special_tokens=True,
......@@ -239,6 +228,18 @@ class MBartTokenizer(XLMRobertaTokenizer):
)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
self.cur_lang_code = self.lang_code_to_id[src_lang]
self.reset_special_tokens() # sets to src_lang
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
return model_inputs
def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
self.cur_lang_code = self.lang_code_to_id[src_lang]
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
def set_tgt_lang_special_tokens(self, lang: str) -> None:
"""Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos]."""
self.cur_lang_code = self.lang_code_to_id[lang]
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]
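For reference, here is a hedged usage sketch of the resulting behaviour, mirroring the assertions in `test_mbart_dataset_truncation` above (the checkpoint name and sentences are illustrative):

```python
from transformers import MBartTokenizer

tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")  # illustrative checkpoint
batch = tok.prepare_translation_batch(
    src_texts=["Sam ate lunch today."],
    src_lang="en_XX",
    tgt_texts=["Sam a mancat pranzul astazi."],
    tgt_lang="ro_RO",
    max_length=4,  # force truncation, as in the test above
)
# Source side (set_src_lang_special_tokens): no prefix, suffix = [eos, src_lang_code].
assert batch["input_ids"][0, -2].item() == tok.eos_token_id
assert batch["input_ids"][0, -1].item() == tok.lang_code_to_id["en_XX"]
# Target side (set_tgt_lang_special_tokens): prefix = [tgt_lang_code], suffix = [eos].
assert batch["decoder_input_ids"][0, 0].item() == tok.lang_code_to_id["ro_RO"]
assert batch["decoder_input_ids"][0, -1].item() == tok.eos_token_id
```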