Unverified Commit 09a2f406 authored by Sam Shleifer, committed by GitHub

Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792)


Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
parent 4b506a37
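The core idea of the change, as a minimal standalone sketch (the class name `LazyLineDataset` and its arguments are illustrative only; the real implementation is the `Seq2SeqDataset` added to `utils.py` below): instead of tokenizing and caching every example up front, each `__getitem__` fetches one raw line with `linecache` and tokenizes it on the fly, so the full tokenized corpus never has to sit in memory.

```python
# Minimal, hypothetical sketch of lazy per-line loading with linecache.
import linecache
from pathlib import Path

from torch.utils.data import Dataset


class LazyLineDataset(Dataset):
    def __init__(self, tokenizer, src_file, max_length=1024):
        self.src_file = Path(src_file)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.n_lines = sum(1 for _ in self.src_file.open())  # only a line count is kept in memory

    def __len__(self):
        return self.n_lines

    def __getitem__(self, index):
        line = linecache.getline(str(self.src_file), index + 1).rstrip("\n")  # linecache is 1-indexed
        return self.tokenizer([line], max_length=self.max_length, truncation=True, return_tensors="pt")
```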
@@ -7,27 +7,24 @@ For `bertabs` instructions, see `bertabs/README.md`.
### Data
XSUM Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```
this should make a directory called xsum/ with files like `test.source`.
To use your own data, copy that file format: each article to be summarized is on its own line (a sketch of writing such files follows the CNN/DailyMail commands below).
CNN/DailyMail data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz
export CNN_DIR=${PWD}/cnn_dm
```
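For a custom dataset, the data files are plain text with one example per line; here is a minimal sketch of producing them (the directory name is illustrative, and the file naming assumes the `{train,val,test}.{source,target}` convention used by the dataset classes in this diff):

```python
# Hypothetical helper for a custom dataset: one article per line in *.source,
# the matching summary on the same line number in *.target.
from pathlib import Path

articles = ["First document to summarize.", "Second document to summarize."]
summaries = ["First summary.", "Second summary."]

data_dir = Path("my_dataset")  # illustrative name; pass this as --data_dir
data_dir.mkdir(exist_ok=True)
(data_dir / "train.source").write_text("\n".join(articles) + "\n")
(data_dir / "train.target").write_text("\n".join(summaries) + "\n")
# repeat for the val.* and test.* splits
```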
WMT16 English-Romanian Translation Data:
```bash
...
```
@@ -64,6 +61,10 @@ Summarization Tips:
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).

**Update 2020-07-18**
Datasets: `Seq2SeqDataset` will be used for all models except MBart, for which `MBartDataset` will be used.
A new dataset class is needed to support multilingual tasks.
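A condensed sketch of how that selection plays out (mirroring the `finetune.py` changes further down in this diff; the checkpoint, data directory, and length limits are illustrative):

```python
# Seq2SeqDataset is the default; the MBart-specific dataset is swapped in only
# when the tokenizer is an MBartTokenizer. Run from examples/seq2seq so that
# utils.py is importable, with a data_dir laid out as described above.
from transformers import AutoTokenizer, MBartTokenizer
from utils import MBartDataset, Seq2SeqDataset

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
dataset_class = MBartDataset if isinstance(tokenizer, MBartTokenizer) else Seq2SeqDataset
dataset = dataset_class(
    tokenizer,
    data_dir="xsum",  # directory containing {train,val,test}.{source,target}
    type_path="train",
    max_source_length=1024,
    max_target_length=56,
)
```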
### Summarization Finetuning
Run/modify `finetune.sh`
@@ -78,8 +79,6 @@ The following command should work on a 16GB GPU:
```bash
    --model_name_or_path facebook/bart-large
```
### Translation Finetuning
First, follow the wmt_en_ro download instructions.
@@ -124,23 +123,6 @@ from transformers import AutoModelForSeq2SeqLM
```python
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
```
#### XSUM Shared Task
Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration.
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
```bash
WANDB_PROJECT='hf_xsum' ./finetune.sh \
--data_dir $XSUM_DIR \
--output_dir xsum_frozen_embs \
--model_name_or_path facebook/bart-large \
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
--num_train_epochs 6 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--logger_name wandb
```
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
### Evaluation Commands
To create summaries for each article in the dataset, we use `run_eval.py`; here are a few commands that run eval for different tasks and models.
...
@@ -15,28 +15,15 @@ from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Conf
try:
    from .finetune import SummarizationModule
    from .finetune import main as ft_main
    from .initialization_utils import init_student, copy_layers
    from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
except ImportError:
    from finetune import SummarizationModule
    from finetune import main as ft_main
    from initialization_utils import init_student, copy_layers
    from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
class BartSummarizationDistiller(SummarizationModule):
@@ -115,11 +102,6 @@ class BartSummarizationDistiller(SummarizationModule):
        if self.different_encoder:
            copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)

    def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
        if mask is not None:
            # mask has False at padding_idx
...
@@ -21,7 +21,6 @@ try:
    from .utils import (
        assert_all_frozen,
        use_task_specific_params,
        lmap,
        flatten_list,
        pickle_save,
@@ -32,12 +31,17 @@ try:
        get_git_info,
        ROUGE_KEYS,
        calculate_bleu_score,
        Seq2SeqDataset,
        MBartDataset,
    )
    from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
except ImportError:
    from utils import (
        Seq2SeqDataset,
        MBartDataset,
        assert_all_frozen,
        use_task_specific_params,
        lmap,
        flatten_list,
        pickle_save,
@@ -48,7 +52,6 @@ except ImportError:
        get_git_info,
        ROUGE_KEYS,
        calculate_bleu_score,
    )
    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
@@ -100,6 +103,7 @@ class SummarizationModule(BaseTransformer):
        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None
        self.dataset_class = Seq2SeqDataset

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -163,7 +167,7 @@ class SummarizationModule(BaseTransformer):
    def _generative_step(self, batch: dict) -> dict:
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
        t0 = time.time()
        generated_ids = self.model.generate(
            input_ids=source_ids,
@@ -187,10 +191,10 @@ class SummarizationModule(BaseTransformer):
    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
@@ -303,6 +307,8 @@ class TranslationModule(SummarizationModule):
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
        if isinstance(self.tokenizer, MBartTokenizer):
            self.dataset_class = MBartDataset

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu_score(preds, target)
...
@@ -9,16 +9,17 @@ from unittest.mock import patch
import pytest
import torch
from pytest import param
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, MBartTokenizer
from transformers.testing_utils import require_multigpu

from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json


logging.basicConfig(level=logging.DEBUG)
@@ -26,6 +27,7 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
    "label_smoothing_eps": 0.2,
    "logger_name": "default",
    "length_penalty": 0.5,
    "cache_dir": "",
@@ -80,11 +82,11 @@ CHEAP_ARGS = {
def _dump_articles(path: Path, articles: list):
    content = "\n".join(articles)
    Path(path).open("w").writelines(content)


ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
@@ -208,7 +210,7 @@ def test_run_eval_bart(model):
@pytest.mark.parametrize(
    ["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
)
def test_finetune(model):
    args_d: dict = CHEAP_ARGS.copy()
@@ -260,22 +262,50 @@ def test_pack_dataset():
    assert orig_paths == new_paths
def test_mbart_dataset_truncation():
    tokenizer = MBartTokenizer.from_pretrained(MBART_TINY)
    tmp_dir = make_test_data_dir()
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    trunc = 4
    src_lang, tgt_lang = "ro_RO", "de_DE"  # NOT WHAT IT WAS TRAINED ON
    train_dataset = MBartDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=trunc,
        max_target_length=1000,  # ignored
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert isinstance(batch, dict)
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed.
        assert batch["input_ids"].shape[1] == trunc
        # show that targets are the same len
        assert batch["decoder_input_ids"].shape[1] == trunc
        # check language codes in correct place
        assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
        assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
        assert max_len_target > trunc  # Truncated
        assert max_len_source > trunc
        break  # No need to test every batch


@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
def test_summarization_dataset_truncation(tok):
    tokenizer = AutoTokenizer.from_pretrained(tok)
    tmp_dir = make_test_data_dir()
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    trunc_target = 4
    train_dataset = Seq2SeqDataset(
        tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
@@ -286,3 +316,4 @@ def test_dataset(tok):
        # show that targets were truncated
        assert batch["decoder_input_ids"].shape[1] == trunc_target  # Truncated
        assert max_len_target > trunc_target  # Truncated
        break  # No need to test every batch
import itertools
import json
import linecache
import os
import pickle
import warnings
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List
@@ -13,50 +15,20 @@ from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler
from transformers import BartTokenizer
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    return tokenizer(
        [line],
        max_length=max_length,
        padding="max_length" if pad_to_max_length else None,
        truncation=True,
        return_tensors=return_tensors,
        **extra_kw,
    )
def lmap(f: Callable, x: Iterable) -> List:
@@ -80,73 +52,111 @@ def trim_batch(
    return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
class Seq2SeqDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        src_lang=None,
        tgt_lang=None,
        prefix="",
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.src_lens = self.get_char_lens(self.src_file)
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix
        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self):
        return len(self.src_lens)

    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "decoder_input_ids": target_ids,
        }

    @staticmethod
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    @staticmethod
    def trim_seq2seq_batch(batch, pad_token_id) -> tuple:
        y = trim_batch(batch["decoder_input_ids"], pad_token_id)
        source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
        return source_ids, source_mask, y

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "decoder_input_ids": y,
        }
        return batch

    def make_sortish_sampler(self, batch_size):
        return SortishSampler(self.src_lens, batch_size)
class MBartDataset(Seq2SeqDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.max_source_length != self.max_target_length:
            warnings.warn(
                f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides."
            )

    def __getitem__(self, index) -> Dict[str, str]:
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {
            "src_texts": source_line,
            "tgt_texts": tgt_line,
        }

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        batch_encoding = self.tokenizer.prepare_translation_batch(
            [x["src_texts"] for x in batch],
            src_lang=self.src_lang,
            tgt_texts=[x["tgt_texts"] for x in batch],
            tgt_lang=self.tgt_lang,
            max_length=self.max_source_length,
        )
        return batch_encoding.data


class SortishSampler(Sampler):
...
@@ -118,12 +118,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
        self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
        self._additional_special_tokens = list(self.lang_code_to_id.keys())
        self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -183,12 +178,6 @@ class MBartTokenizer(XLMRobertaTokenizer):
            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones

    def prepare_translation_batch(
        self,
        src_texts: List[str],
@@ -215,7 +204,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
        """
        if max_length is None:
            max_length = self.max_len
        self.set_src_lang_special_tokens(src_lang)
        model_inputs: BatchEncoding = self(
            src_texts,
            add_special_tokens=True,
@@ -227,7 +216,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
        )
        if tgt_texts is None:
            return model_inputs
        self.set_tgt_lang_special_tokens(tgt_lang)
        decoder_inputs: BatchEncoding = self(
            tgt_texts,
            add_special_tokens=True,
@@ -239,6 +228,18 @@ class MBartTokenizer(XLMRobertaTokenizer):
        )
        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v
        self.set_src_lang_special_tokens(src_lang)  # sets to src_lang
        return model_inputs

    def set_src_lang_special_tokens(self, src_lang) -> None:
        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
        self.cur_lang_code = self.lang_code_to_id[src_lang]
        self.prefix_tokens = []
        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]

    def set_tgt_lang_special_tokens(self, lang: str) -> None:
        """Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix=[eos]."""
        self.cur_lang_code = self.lang_code_to_id[lang]
        self.prefix_tokens = [self.cur_lang_code]
        self.suffix_tokens = [self.eos_token_id]
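For reference, a small sketch of how the two helpers above shape the output of `prepare_translation_batch` (the checkpoint and sentences are illustrative; the token layout follows the docstrings above and the assertions in the new MBart dataset test):

```python
from transformers import MBartTokenizer

# assumes the public facebook/mbart-large-en-ro checkpoint
tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
batch = tok.prepare_translation_batch(
    ["UN Chief Says There Is No Military Solution in Syria"],
    src_lang="en_XX",
    tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
    tgt_lang="ro_RO",
)
print(batch["input_ids"][0])          # source ids, ending in [eos, en_XX code] (plus any padding)
print(batch["decoder_input_ids"][0])  # starts with the ro_RO code, ends with eos (plus any padding)
```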