Commit 4c3ac4a7 authored by Rémi Louf

here's one big commit

parent 932543f7
......@@ -393,7 +393,8 @@ This fine-tuned model is available as a checkpoint under the reference
## Seq2seq model fine-tuning
Based on the script [`run_seq2seq_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_seq2seq_finetuning.py).
Based on the script
[`run_summarization_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_summarization_finetuning.py).
Before running this script you should download **both** CNN and Daily Mail
datasets from [Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the
......@@ -412,7 +413,7 @@ archive.
```bash
export DATA_PATH=/path/to/dataset/
python run_seq2seq_finetuning.py \
python run_summarization_finetuning.py \
--output_dir=output \
--model_type=bert2bert \
--model_name_or_path=bert2bert \
......
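The data loader in the fine-tuning script (`TextDataset` below) reads `*.story` files from per-dataset `stories/` folders, so the extracted archives are assumed to end up laid out as:

```bash
$DATA_PATH/
├── cnn/stories/*.story
└── dailymail/stories/*.story
```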
# coding=utf-8
# Copyright 2018 The Microsoft Reseach team and The HuggingFace Inc. team.
# Copyright (c) 2018 Microsoft and The HuggingFace Inc. All rights reserved.
# Copyright 2019 The HuggingFace Inc. team.
# Copyright (c) 2019 The HuggingFace Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,18 +18,21 @@
import argparse
from collections import deque
import logging
import os
import pickle
import random
import os
import sys
import numpy as np
from tqdm import tqdm, trange
import torch
from torch.utils.data import Dataset, RandomSampler
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AutoTokenizer, Model2Model
from transformers import AutoTokenizer, PreTrainedSeq2seq, Model2Model
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
def set_seed(args):
......@@ -61,7 +64,7 @@ class TextDataset(Dataset):
def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
assert os.path.isdir(data_dir)
# Load features that have already been computed if present
# Load the features that have already been computed, if any
cached_features_file = os.path.join(
data_dir, "cached_lm_{}_{}".format(block_size, prefix)
)
......@@ -72,12 +75,11 @@ class TextDataset(Dataset):
return
logger.info("Creating features from dataset at %s", data_dir)
self.examples = []
datasets = ["cnn", "dailymail"]
self.examples = {"source": [], "target": []}
for dataset in datasets:
path_to_stories = os.path.join(data_dir, dataset, "stories")
assert os.path.isdir(path_to_stories)
story_filenames_list = os.listdir(path_to_stories)
for story_filename in story_filenames_list:
path_to_story = os.path.join(path_to_stories, story_filename)
......@@ -85,19 +87,19 @@ class TextDataset(Dataset):
continue
with open(path_to_story, encoding="utf-8") as source:
try:
raw_story = source.read()
story, summary = process_story(raw_story)
except IndexError: # skip ill-formed stories
story_lines, summary_lines = process_story(raw_story)
if len(summary_lines) == 0 or len(story_lines) == 0:
continue
story = tokenizer.encode(story)
story_seq = _fit_to_block_size(story, block_size)
summary = tokenizer.encode(summary)
summary_seq = _fit_to_block_size(summary, block_size)
story_token_ids, summary_token_ids = _encode_for_summarization(
story_lines, summary_lines, tokenizer
)
story_seq = _fit_to_block_size(story_token_ids, block_size)
self.examples["source"].append(story_seq)
self.examples.append((story_seq, summary_seq))
summary_seq = _fit_to_block_size(summary_token_ids, block_size)
self.examples["summary"].append(summary_seq)
logger.info("Saving features into cache file %s", cached_features_file)
with open(cached_features_file, "wb") as sink:
......@@ -107,7 +109,10 @@ class TextDataset(Dataset):
return len(self.examples["source"])
def __getitem__(self, items):
return torch.tensor(self.examples[items])
return (
torch.tensor(self.examples["source"][items]),
torch.tensor(self.examples["target"][items]),
)
def process_story(raw_story):
......@@ -119,33 +124,55 @@ def process_story(raw_story):
Returns:
story_lines, summary_lines: the lines of the article and of the
highlights; summary_lines is empty if the story has no highlights.
"""
file_lines = list(
nonempty_lines = list(
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
)
# for some unknown reason some lines are missing a period; add it
file_lines = [_add_missing_period(line) for line in file_lines]
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
# gather article lines
story_lines = []
lines = deque(file_lines)
lines = deque(nonempty_lines)
while True:
try:
element = lines.popleft()
if element.startswith("@highlight"):
break
story_lines.append(element)
except IndexError as ie: # if "@highlight" absent from file
raise ie
except IndexError:
# if "@highlight" is absent from the file we pop
# all elements until there are none left.
return story_lines, []
# gather summary lines
highlights_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
# join the lines
story = " ".join(story_lines)
summary = " ".join(highlights_lines)
return story_lines, summary_lines
return story, summary
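# Example (mirrors the unit tests below): a raw story such as
#   "It was the best of times\n\n@highlight\n\nIt was the worst of times"
# is split into
#   story_lines   = ["It was the best of times."]
#   summary_lines = ["It was the worst of times."]
# with the missing trailing periods added by _add_missing_period.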
def _encode_for_summarization(story_lines, summary_lines, tokenizer):
""" Encode the story and summary lines, and join them
as specified in [1] by using `[SEP] [CLS]` tokens to separate
sentences.
"""
story_lines_token_ids = [
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
for line in story_lines
]
summary_lines_token_ids = [
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
for line in summary_lines
]
story_token_ids = [
token for sentence in story_lines_token_ids for token in sentence
]
summary_token_ids = [
token for sentence in summary_lines_token_ids for token in sentence
]
return story_token_ids, summary_token_ids
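# The resulting layout, illustrated with BERT's special tokens
# ([CLS] = 101, [SEP] = 102): each sentence is wrapped individually, so
#   ["Sentence one.", "Sentence two."]
# becomes
#   [101, <ids of "Sentence one.">, 102, 101, <ids of "Sentence two.">, 102]
# which is the "[SEP] [CLS]"-separated format described in [1].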
def _add_missing_period(line):
......@@ -170,8 +197,11 @@ def _fit_to_block_size(sequence, block_size):
def mask_padding_tokens(sequence):
""" Replace the padding token with -1 values """
return [s if s != 0 else -1 for s in sequence]
""" Padding token, encoded as 0, are represented by the value -1 in the
masks """
padded = sequence.clone()
padded[padded == 0] = -1
return padded
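# Example:
#   mask_padding_tokens(torch.tensor([145, 78, 0, 0]))
#   -> tensor([145,  78,  -1,  -1])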
def load_and_cache_examples(args, tokenizer):
......@@ -179,81 +209,181 @@ def load_and_cache_examples(args, tokenizer):
return dataset
def compute_token_type_ids(batch, separator_token_id):
""" Segment embeddings as described in [1]
The values {0,1} were found in the repository [2].
Attributes:
batch: torch.Tensor, size [batch_size, block_size]
Batch of input.
separator_token_id: int
The value of the token that separates the segments.
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
arXiv preprint arXiv:1908.08345 (2019).
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
"""
batch_embeddings = []
for sequence in batch:
sentence_num = 0  # the segment alternation restarts for each sequence
embeddings = []
for s in sequence:
if s == separator_token_id:
sentence_num += 1
embeddings.append(sentence_num % 2)
batch_embeddings.append(embeddings)
return torch.tensor(batch_embeddings)
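# Illustration with BERT's [CLS] token (id 101) as the separator and toy
# token ids; sentences alternate between the two segment values:
#   batch = torch.tensor([[101, 7592, 102, 101, 2129, 102]])
#   compute_token_type_ids(batch, separator_token_id=101)
#   -> tensor([[1, 1, 1, 0, 0, 0]])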
# ----------
# Optimizers
# ----------
class BertSumOptimizer(object):
""" Specific optimizer for BertSum.
As described in [1], the authors fine-tune BertSum for abstractive
summarization using two Adam Optimizers with different warm-up steps and
learning rate. They also use a custom learning rate scheduler.
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
arXiv preprint arXiv:1908.08345 (2019).
"""
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-9):
self.encoder = model.encoder
self.decoder = model.decoder
self.lr = lr
self.warmup_steps = warmup_steps
self.optimizers = {
"encoder": Adam(
model.encoder.parameters(),
lr=lr["encoder"],
betas=(beta_1, beta_2),
eps=eps,
),
"decoder": Adam(
model.decoder.parameters(),
lr=lr["decoder"],
betas=(beta_1, beta_2),
eps=eps,
),
}
self._step = 0
def _update_rate(self, stack):
# Noam schedule used in PreSumm: lr * min(step^-0.5, step * warmup^-1.5)
return self.lr[stack] * min(
self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5)
)
def zero_grad(self):
self.optimizers["decoder"].zero_grad()
self.optimizers["encoder"].zero_grad()
def step(self):
self._step += 1
for stack, optimizer in self.optimizers.items():
new_rate = self._update_rate(stack)
for param_group in optimizer.param_groups:
param_group["lr"] = new_rate
optimizer.step()
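# Usage sketch, with the values used in train() below:
#   lr = {"encoder": 0.002, "decoder": 0.2}
#   warmup_steps = {"encoder": 20000, "decoder": 10000}
#   optimizer = BertSumOptimizer(model, lr, warmup_steps)
#   loss.backward()
#   optimizer.step()       # one Adam step per stack, each with its own schedule
#   optimizer.zero_grad()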
# ------------
# Train
# ------------
def train(args, train_dataset, model, tokenizer):
def train(args, model, tokenizer):
""" Fine-tune the pretrained model on the corpus. """
set_seed(args)
# Prepare the data loading
args.train_bach_size = 1
# Load the data
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_dataset = load_and_cache_examples(args, tokenizer)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(
train_dataset, sampler=train_sampler, batch_size=args.train_bach_size
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size
)
# Prepare the optimizer and schedule (linear warmup and decay)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in model.named_parameters()
if not any(nd in n for nd in no_decay)
],
"weight_decay": args.weight_decay,
},
{
"params": [
p
for n, p in model.named_parameters()
if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]
optimizer = AdamW(
optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon
# Training schedule
if args.max_steps > 0:
t_total = args.max_steps
args.num_train_epochs = t_total // (
len(train_dataloader) // args.gradient_accumulation_steps + 1
)
scheduler = WarmupLinearSchedule(
optimizer, warmup_steps=args.warmup_steps, t_total=t_total
else:
t_total = (
len(train_dataloader)
// args.gradient_accumulation_steps
* args.num_train_epochs
)
# Prepare the optimizer
lr = {"encoder": 0.002, "decoder": 0.2}
warmup_steps = {"encoder": 20000, "decoder": 10000}
optimizer = BertSumOptimizer(model, lr, warmup_steps)
# Train
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size
* args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
logger.info(
" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size
)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size * args.gradient_accumulation_steps
# * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
)
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True)
set_seed(args)
global_step = 0
tr_loss = 0.0
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
for step, batch in enumerate(epoch_iterator):
source = ([s for s, _ in batch]).to(args.device)
target = ([t for _, t in batch]).to(args.device)
source, target = batch
token_type_ids = compute_token_type_ids(source, tokenizer.cls_token_id)
labels_src = mask_padding_tokens(source)
labels_tgt = mask_padding_tokens(target)
source = source.to(args.device)
target = target.to(args.device)
token_type_ids = token_type_ids.to(args.device)
labels_src = labels_src.to(args.device)
labels_tgt = labels_tgt.to(args.device)
model.train()
outputs = model(source, target, decoder_lm_labels=mask_padding_tokens(target))
outputs = model(
source,
target,
token_type_ids=token_type_ids,
decoder_encoder_attention_mask=labels_src,
decoder_attention_mask=labels_tgt,
decoder_lm_labels=labels_tgt,
decoder_initialize_randomly=True,
)
loss = outputs[0]
print(loss)
if args.gradient_accumulation_steps > 1:
loss /= args.gradient_accumulation_steps
loss.backward()
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
scheduler.step()
model.zero_grad()
global_step += 1
......@@ -268,6 +398,68 @@ def train(args, train_dataset, model, tokenizer):
return global_step, tr_loss / global_step
# ------------
# Evaluate
# ------------
def evaluate(args, model, tokenizer, prefix=""):
set_seed(args)
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size
)
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
eval_loss = 0.0
nb_eval_steps = 0
model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
source, target = batch
labels_src = mask_padding_tokens(source)
labels_tgt = mask_padding_tokens(target)
# tensor.to() is not in-place; keep the returned copies
source = source.to(args.device)
target = target.to(args.device)
labels_src = labels_src.to(args.device)
labels_tgt = labels_tgt.to(args.device)
with torch.no_grad():
outputs = model(
source,
target,
decoder_encoder_attention_mask=labels_src,
decoder_attention_mask=labels_tgt,
decoder_lm_labels=labels_tgt,
)
lm_loss = outputs[0]
eval_loss += lm_loss.mean().item()
nb_eval_steps += 1
eval_loss = eval_loss / nb_eval_steps
perplexity = torch.exp(torch.tensor(eval_loss))
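# e.g. an average cross-entropy of 2.0 nats corresponds to a perplexity of exp(2.0) ≈ 7.39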
result = {"perplexity": perplexity}
# Save the evaluation's results
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
return result
def main():
parser = argparse.ArgumentParser()
......@@ -289,7 +481,23 @@ def main():
# Optional parameters
parser.add_argument(
"--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--do_evaluate",
type=bool,
default=False,
help="Run model evaluation on out-of-sample data.",
)
parser.add_argument("--do_train", type=bool, default=False, help="Run training.")
parser.add_argument(
"--do_overwrite_output_dir",
type=bool,
default=False,
help="Whether to overwrite the output dir.",
)
parser.add_argument(
"--model_name_or_path",
......@@ -303,12 +511,6 @@ def main():
type=str,
help="The decoder architecture to be fine-tuned.",
)
parser.add_argument(
"--learning_rate",
default=5e-5,
type=float,
help="The initial learning rate for Adam.",
)
parser.add_argument(
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
)
......@@ -318,43 +520,100 @@ def main():
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument(
"--to_cpu", default=False, type=bool, help="Whether to force training on CPU."
)
parser.add_argument(
"--num_train_epochs",
default=1,
type=int,
help="Total number of training epochs to perform.",
)
parser.add_argument("--seed", default=42, type=int)
parser.add_argument(
"--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps."
)
parser.add_argument(
"--weight_decay", default=0.0, type=float, help="Weight deay if we apply some."
"--per_gpu_train_batch_size",
default=4,
type=int,
help="Batch size per GPU/CPU for training.",
)
parser.add_argument("--seed", default=42, type=int)
args = parser.parse_args()
if args.model_type != "bert":
if (
os.path.exists(args.output_dir)
and os.listdir(args.output_dir)
and args.do_train
and not args.do_overwrite_output_dir
):
raise ValueError(
"Only the BERT architecture is currently supported for seq2seq."
"Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite.".format(
args.output_dir
)
)
# Set up training device
# device = torch.device("cpu")
# Set seed
set_seed(args)
if args.to_cpu or not torch.cuda.is_available():
args.device = torch.device("cpu")
args.n_gpu = 0
else:
args.device = torch.device("cuda")
args.n_gpu = torch.cuda.device_count()
# Load pretrained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
model = Model2Model.from_pretrained(args.model_name_or_path)
# model.to(device)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
0,
args.device,
args.n_gpu,
False,
False,
)
logger.info("Training/evaluation parameters %s", args)
# Training
train_dataset = load_and_cache_examples(args, tokenizer)
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
# logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
# Train the model
model.to(args.device)
if args.do_train:
global_step, tr_loss = train(args, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
# Evaluate the model
results = {}
if args.do_evaluate:
checkpoints = []
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
encoder_checkpoint = os.path.join(checkpoint, "encoder")
decoder_checkpoint = os.path.join(checkpoint, "decoder")
model = PreTrainedSeq2seq.from_pretrained(
encoder_checkpoint, decoder_checkpoint
)
model.to(args.device)
results = "placeholder"
return results
if __name__ == "__main__":
......
......@@ -14,7 +14,7 @@
# limitations under the License.
import unittest
from run_seq2seq_finetuning import _fit_to_block_size, process_story
from run_summarization_finetuning import _fit_to_block_size, process_story
class DataLoaderTest(unittest.TestCase):
......@@ -43,15 +43,16 @@ class DataLoaderTest(unittest.TestCase):
raw_story = """It was the year of Our Lord one thousand seven hundred and
seventy-five.\n\nSpiritual revelations were conceded to England at that
favoured period, as at this."""
with self.assertRaises(IndexError):
process_story(raw_story)
_, summary = process_story(raw_story)
self.assertEqual(summary, [])
def test_process_empty_story(self):
""" An empty story should also raise and exception.
"""
raw_story = ""
with self.assertRaises(IndexError):
process_story(raw_story)
story, summary = process_story(raw_story)
self.assertEqual(story, [])
self.assertEqual(summary, [])
def test_story_with_missing_period(self):
raw_story = (
......@@ -59,17 +60,16 @@ class DataLoaderTest(unittest.TestCase):
"seventy-five\n\nSpiritual revelations were conceded to England "
"at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
)
story, summary = process_story(raw_story)
story_lines, summary_lines = process_story(raw_story)
expected_story = (
"It was the year of Our Lord one thousand seven hundred and "
"seventy-five. Spiritual revelations were conceded to England at that "
"favoured period, as at this."
)
self.assertEqual(expected_story, story)
expected_story_lines = [
"It was the year of Our Lord one thousand seven hundred and seventy-five.",
"Spiritual revelations were conceded to England at that favoured period, as at this.",
]
self.assertEqual(expected_story_lines, story_lines)
expected_summary = "It was the best of times."
self.assertEqual(expected_summary, summary)
expected_summary_lines = ["It was the best of times."]
self.assertEqual(expected_summary_lines, summary_lines)
if __name__ == "__main__":
......
......@@ -87,7 +87,7 @@ if is_torch_available():
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_seq2seq import Model2Model
from .modeling_seq2seq import PreTrainedSeq2seq, Model2Model
# Optimization
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
......
# coding=utf-8
# Copyright (c) 2019 Yang Liu
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
A general wrapper around models with LM heads to generate sequences
using beam search.
"""
import torch
from torch import nn
class ModelWithBeamSearch(nn.Module):
def __init__(
self,
model,
beam_size,
start_token_id,
end_token_id,
pad_token_id,
min_length,
max_length,
alpha,
block_trigram=True,
):
"""
Attributes:
mask_word_id: token id that corresponds to the mask
"""
super(ModelWithBeamSearch, self).__init__()
self.model = model
self.beam_size = beam_size
self.start_token_id = start_token_id
self.end_token_id = end_token_id
self.pad_token_id = pad_token_id
self.min_length = min_length
self.max_length = max_length
self.alpha = alpha
self.block_trigram = block_trigram
def forward(self, input_ids, **kwargs):
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific if the key starts with `decoder_`
kwargs_encoder = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("decoder_")
}
kwargs_decoder = {
argument[len("decoder_"):]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
batch_size = input_ids.size(0)
# Variables that keep track of the status of the search
hypotheses = [[] for _ in range(batch_size)]
batch_offset = torch.arange(batch_size, dtype=torch.long)
beam_offset = torch.arange(
0,
batch_size * self.beam_size,
step=self.beam_size,
dtype=torch.long,
)
growing_beam = torch.full(
(batch_size * self.beam_size, 1),
self.start_token_id,
dtype=torch.long,
)
topk_log_probabilities = torch.tensor(
[0.0] + [float("-inf")] * (self.beam_size - 1),
dtype=torch.float,
).repeat(batch_size)
# Forward pass on the encoder
encoder_outputs = self.model.encoder(input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]  # last-layer hidden states
kwargs_decoder["encoder_hidden_states"] = tile(
encoder_hidden_states, self.beam_size, dim=0
)
results = {}
results["predictions"] = [[] for _ in batch_size]
results["scores"] = [[] for _ in batch_size]
for step in range(self.max_length):
decoder_input = growing_beam[:, -1]
outputs = self.model.decoder(decoder_input, **kwargs_decoder)
log_probabilities = torch.nn.functional.log_softmax(outputs[1])
vocab_size = log_probabilities.size(-1)
# The batch size changes as some beams finish so we define:
_B = log_probabilities.size(0) // self.beam_size
# Multiply each beam probability with the probability of the
# next token (conditioned on the words in the beam).
log_probabilities += topk_log_probabilities.view(-1, 1)
# if the beam has not attained the minimum required length we
# make the end token arbitrarily unlikely.
if step < self.min_length:
log_probabilities[:, self.end_token_id] = -1e20
# Remove repeating tri-grams
if self.block_trigram:
if step + 1 > 3:
for i in range(_B * self.beam_size):
tokens = [t for t in growing_beam[i]]
trigrams = [(tokens[k - 1], tokens[k], tokens[k + 1]) for k in range(1, len(tokens) - 1)]
last_trigram = tuple(trigrams[-1])
if last_trigram in trigrams[:-1]:
log_probabilities[i] = -1e20
# Find the `beam_size` (previous_beam + token) combinations with
# the highest score
topk_log_probabilities, topk_ids = log_probabilities.view(
_B, self.beam_size * vocab_size
).topk(self.beam_size, dim=1)
# Apply the length penalty. The +1 accounts for the [EOS] token
# that will be added if the beam ends.
length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha
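# This is the length penalty of Wu et al. (2016) (GNMT); larger values
# of alpha favor longer hypotheses.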
topk_scores = topk_log_probabilities / length_penalty
# Retrieve the corresponding respective beam and token id
# topk_token_ids[i] will be added to topk_beam_ids[i]
topk_beam_ids = topk_ids.div(vocab_size)
topk_token_ids = topk_ids.fmod(vocab_size)
# Retrieve the row index of the surviving beams in the original
# view of the log_probabilities tensor
surviving_beams_rows = (
topk_beam_ids + beam_offset[:_B].view(-1, 1)
).view(-1)
# Append the last predictions
growing_beam = torch.cat(
[
growing_beam.index_select(0, surviving_beams_rows),
topk_token_ids.view(-1, 1),
],
1,
)
# Check whether any beam search ended during this growth step, and
# whether the top (most probable) beam ended for some element of the
# batch.
is_finished = topk_token_ids.eq(self.end_token_id)
if step + 1 == self.max_length:
is_finished.fill_(1)
is_top_beam_finished = is_finished[:, 0].eq(1)
# Save the finished searches
if is_finished.any():
predictions = growing_beam.view(-1, self.beam_size, growing_beam.size(1))
for i in range(is_finished.size(0)):
if is_top_beam_finished[i]:
is_finished[i].fill_(1)
finished_hyp = is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
b = batch_offset[i]
for j in finished_hyp:
hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
# If the batch reached the end, save the best hypotheses
# in terms of length-penalized score.
if is_top_beam_finished[i]:
best_hyp = sorted(
hypotheses[b], key=lambda x: x[0], reverse=True
)
best_score, best_prediction = best_hyp[0]
results["scores"][b].append(best_score)
results["predictions"][b].append(best_prediction)
non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
if len(non_finished) == 0:
break
# Remove finished batches for the next step.
topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished)
batch_offset = batch_offset.index_select(0, non_finished)
growing_beam = predictions.index_select(0, non_finished).view(
-1, growing_beam.size(-1)
)
# Re-order the state for the next pass
surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
"encoder_hidden_states"
].index_select(0, surviving_beams_rows)
return results
def tile(x, count, dim=0):
"""
Tiles `x` along dimension `dim` `count` times.
Example:
>> ex = torch.tensor([1,2],[3,4])
>> tile(ex, 2, 0)
torch.Tensor([[1,2],[1,2],[3,4],[3,4]])
"""
perm = list(range(len(x.size())))
if dim != 0:
perm[0], perm[dim] = perm[dim], perm[0]
x = x.permute(perm).contiguous()
out_size = list(x.size())
out_size[0] *= count
batch = x.size(0)
x = (
x.view(batch, -1)
.transpose(0, 1)
.repeat(count, 1)
.transpose(0, 1)
.contiguous()
.view(*out_size)
)
if dim != 0:
x = x.permute(perm).contiguous()
return x
......@@ -646,7 +646,7 @@ class BertModel(BertPreTrainedModel):
if attention_mask.dim() == 2:
if self.config.is_decoder:
batch_size, seq_length = input_ids.size()
seq_ids = torch.arange(seq_length)
seq_ids = torch.arange(seq_length, device=input_ids.device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
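# e.g. for seq_length = 3 the causal mask lets position i attend to
# positions j <= i:
#   [[1, 0, 0],
#    [1, 1, 0],
#    [1, 1, 1]]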
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
......@@ -660,6 +660,13 @@ class BertModel(BertPreTrainedModel):
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# If a 2D encoder attention mask is provided for the cross-attention
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_attention_mask is not None:
encoder_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
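# e.g. an encoder attention mask [1, 1, 0] becomes [0.0, 0.0, -10000.0]:
# adding it to the attention scores effectively removes the padded
# position after the softmax.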
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
......@@ -819,7 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):
self.bert.embeddings.word_embeddings)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
masked_lm_labels=None, lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None):
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
......@@ -838,11 +845,8 @@ class BertForMaskedLM(BertPreTrainedModel):
# 1. If a tensor that contains the indices of masked labels is provided,
# the cross-entropy is the MLM cross-entropy that measures the likelihood
# of predictions for masked words.
# 2. If encoder hidden states are provided we are in a causal situation where we
# 2. If `lm_labels` is provided we are in a causal scenario where we
# try to predict the next word for each input in the encoder.
if masked_lm_labels is not None and lm_labels is not None:
raise AttributeError("Masked LM training with an encoder-decoder is not supported.")
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1) # -1 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
......@@ -851,9 +855,9 @@ class BertForMaskedLM(BertPreTrainedModel):
if lm_labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :]
lm_labels = lm_labels[:, 1:, :]
lm_labels = lm_labels[:, 1:]
loss_fct = CrossEntropyLoss(ignore_index=-1)
seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1))
outputs = (seq2seq_loss,) + outputs
return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
......
......@@ -17,13 +17,12 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
import torch
from torch import nn
from .file_utils import add_start_docstrings
from .modeling_auto import AutoModel, AutoModelWithLMHead
from .modeling_utils import PreTrainedModel, SequenceSummary
logger = logging.getLogger(__name__)
......@@ -43,7 +42,13 @@ class PreTrainedSeq2seq(nn.Module):
self.decoder = decoder
@classmethod
def from_pretrained(cls, encoder_pretrained_model_name_or_path=None, decoder_pretrained_model_name_or_path=None, *model_args, **kwargs):
def from_pretrained(
cls,
encoder_pretrained_model_name_or_path=None,
decoder_pretrained_model_name_or_path=None,
*model_args,
**kwargs
):
r""" Instantiates an encoder and a decoder from one or two base classes
of the library from pre-trained model checkpoints.
......@@ -108,23 +113,28 @@ class PreTrainedSeq2seq(nn.Module):
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific if the key starts with `decoder_`
kwargs_decoder = {}
kwargs_encoder = kwargs
for key in kwargs_encoder.keys():
if key.startswith("decoder_"):
kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
kwargs_encoder = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("decoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs.pop("encoder_model", None)
encoder = kwargs_encoder.pop("encoder_model", None)
if encoder is None:
kwargs_encoder["is_decoder"] = False
encoder = AutoModel.from_pretrained(
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
)
decoder = kwargs.pop("decoder_model", None)
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
kwargs_decoder["is_decoder"] = True
decoder = AutoModelWithLMHead.from_pretrained(
......@@ -135,6 +145,12 @@ class PreTrainedSeq2seq(nn.Module):
return model
def save_pretrained(self, save_directory):
""" Save a Seq2Seq model and its configuration file in a format
such that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained` """
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
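# Round-trip sketch (mirroring how the fine-tuning script reloads
# checkpoints): encoder and decoder are saved to separate sub-directories
# and loaded back independently.
#   model.save_pretrained("output")
#   model = PreTrainedSeq2seq.from_pretrained("output/encoder", "output/decoder")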
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
""" The forward pass on a seq2eq depends what we are performing:
......@@ -155,22 +171,29 @@ class PreTrainedSeq2seq(nn.Module):
"""
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific if the key starts with `decoder_`
kwargs_decoder = {}
kwargs_encoder = kwargs
for key in kwargs_encoder.keys():
if key.startswith("decoder_"):
kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
kwargs_encoder = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("decoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0][-1] # output of the encoder *stack*
encoder_hidden_states = encoder_outputs[0][
-1
] # output of the encoder *stack*
else:
encoder_outputs = ()
# Decode
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :]
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
return decoder_outputs + encoder_outputs
......@@ -201,9 +224,25 @@ class Model2Model(PreTrainedSeq2seq):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
model = super(Model2Model, cls).from_pretrained(encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
if (
"bert" not in pretrained_model_name_or_path
or "roberta" in pretrained_model_name_or_path
or "distilbert" in pretrained_model_name_or_path
):
raise ValueError("Only the Bert model is currently supported.")
model = super(Model2Model, cls).from_pretrained(
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
**kwargs)
**kwargs
)
# Some architectures require the decoder to be initialized randomly
# before fine-tuning.
if kwargs.get("decoder_initialize_randomly", False):
model.decoder.init_weights()
return model
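# Usage sketch: a single BERT checkpoint initializes both the encoder and
# the decoder (only BERT-style checkpoints pass the check above), e.g.
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   model = Model2Model.from_pretrained("bert-base-uncased")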
......