"vscode:/vscode.git/clone" did not exist on "aaaed56ffcee3433fa57345b70ff68db8e8bde07"
Commit 9660ba1c authored by Rémi Louf, committed by Julien Chaumond

Add beam search

parent 1c71ecc8
# coding=utf-8
# Copyright 2019 The HuggingFace Inc. team.
# Copyright (c) 2019 The HuggingFace Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning seq2seq models for sequence generation."""
import argparse
import functools
import logging
import os
import random
import sys
import numpy as np
from tqdm import tqdm, trange
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import (
AutoTokenizer,
BertForMaskedLM,
BertConfig,
PreTrainedEncoderDecoder,
Model2Model,
)
from utils_summarization import (
CNNDailyMailDataset,
encode_for_summarization,
fit_to_block_size,
build_lm_labels,
build_mask,
compute_token_type_ids,
)
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# ------------
# Load dataset
# ------------
def load_and_cache_examples(args, tokenizer):
dataset = CNNDailyMailDataset(tokenizer, data_dir=args.data_dir)
return dataset
def collate(data, tokenizer, block_size):
""" List of tuple as an input. """
# remove the files with empty an story/summary, encode and fit to block
data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data)
data = [
encode_for_summarization(story, summary, tokenizer) for story, summary in data
]
data = [
(
fit_to_block_size(story, block_size, tokenizer.pad_token_id),
fit_to_block_size(summary, block_size, tokenizer.pad_token_id),
)
for story, summary in data
]
stories = torch.tensor([story for story, summary in data])
summaries = torch.tensor([summary for story, summary in data])
encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
encoder_mask = build_mask(stories, tokenizer.pad_token_id)
decoder_mask = build_mask(summaries, tokenizer.pad_token_id)
lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id)
return (
stories,
summaries,
encoder_token_type_ids,
encoder_mask,
decoder_mask,
lm_labels,
)
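
# Illustration only: a hedged sketch of what the padding helpers imported from
# utils_summarization do to a toy sequence of token ids, assuming a pad_token_id
# of 0 and a block_size of 5.
#
#   fit_to_block_size([7, 8, 9], 5, 0)                 -> [7, 8, 9, 0, 0]
#   build_mask(torch.tensor([7, 8, 9, 0, 0]), 0)       -> tensor([1, 1, 1, 0, 0])
#   build_lm_labels(torch.tensor([7, 8, 9, 0, 0]), 0)  -> tensor([7, 8, 9, -1, -1])
#
# `collate` applies these helpers to every (story, summary) pair and stacks the
# results into the six tensors returned above.
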
# ----------
# Optimizers
# ----------
class BertSumOptimizer(object):
""" Specific optimizer for BertSum.
As described in [1], the authors fine-tune BertSum for abstractive
    summarization using two Adam optimizers with different warm-up steps and
    learning rates. They also use a custom learning rate scheduler.
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
arXiv preprint arXiv:1908.08345 (2019).
"""
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
self.encoder = model.encoder
self.decoder = model.decoder
self.lr = lr
self.warmup_steps = warmup_steps
self.optimizers = {
"encoder": Adam(
model.encoder.parameters(),
lr=lr["encoder"],
betas=(beta_1, beta_2),
eps=eps,
),
"decoder": Adam(
model.decoder.parameters(),
lr=lr["decoder"],
betas=(beta_1, beta_2),
eps=eps,
),
}
self._step = 0
    def _update_rate(self, stack):
        # Noam-style schedule used by BertSum: linear warm-up followed by an
        # inverse square-root decay of the learning rate.
        return self.lr[stack] * min(
            self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5)
        )
    def zero_grad(self):
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()
def step(self):
self._step += 1
for stack, optimizer in self.optimizers.items():
new_rate = self._update_rate(stack)
for param_group in optimizer.param_groups:
param_group["lr"] = new_rate
optimizer.step()
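
# A minimal sketch of how the scheduler above behaves, assuming the BertSum
# (Noam-style) rule lr * min(step ** -0.5, step * warmup_steps ** -1.5) and the
# encoder settings used in `train` below (lr=0.002, warmup_steps=20000):
#
#   step 1:      0.002 * 1 * 20000 ** -1.5  ~ 7.1e-10  (linear warm-up)
#   step 20000:  0.002 * 20000 ** -0.5      ~ 1.4e-5   (peak, end of warm-up)
#   step 80000:  0.002 * 80000 ** -0.5      ~ 7.1e-6   (inverse square-root decay)
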
# ------------
# Train
# ------------
def train(args, model, tokenizer):
""" Fine-tune the pretrained model on the corpus. """
set_seed(args)
# Load the data
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_dataset = load_and_cache_examples(args, tokenizer)
train_sampler = RandomSampler(train_dataset)
model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
train_dataloader = DataLoader(
train_dataset,
sampler=train_sampler,
batch_size=args.train_batch_size,
collate_fn=model_collate_fn,
)
# Training schedule
if args.max_steps > 0:
t_total = args.max_steps
args.num_train_epochs = t_total // (
len(train_dataloader) // args.gradient_accumulation_steps + 1
)
else:
t_total = (
len(train_dataloader)
// args.gradient_accumulation_steps
* args.num_train_epochs
)
# Prepare the optimizer
lr = {"encoder": 0.002, "decoder": 0.2}
warmup_steps = {"encoder": 20000, "decoder": 10000}
optimizer = BertSumOptimizer(model, lr, warmup_steps)
# Train
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(
" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size
)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size * args.gradient_accumulation_steps
# * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
)
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
model.zero_grad()
train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True)
global_step = 0
tr_loss = 0.0
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
for step, batch in enumerate(epoch_iterator):
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
source = source.to(args.device)
target = target.to(args.device)
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
encoder_mask = encoder_mask.to(args.device)
decoder_mask = decoder_mask.to(args.device)
lm_labels = lm_labels.to(args.device)
model.train()
outputs = model(
source,
target,
encoder_token_type_ids=encoder_token_type_ids,
encoder_attention_mask=encoder_mask,
decoder_attention_mask=decoder_mask,
decoder_lm_labels=lm_labels,
)
loss = outputs[0]
print(loss)
if args.gradient_accumulation_steps > 1:
loss /= args.gradient_accumulation_steps
loss.backward()
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
model.zero_grad()
global_step += 1
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
if args.max_steps > 0 and global_step > args.max_steps:
train_iterator.close()
break
return global_step, tr_loss / global_step
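
# Note on the effective batch size: each optimizer update consumes
# per_gpu_train_batch_size * max(1, n_gpu) * gradient_accumulation_steps examples.
# With the defaults defined below (4 examples per GPU, one GPU, no accumulation)
# that is 4 * 1 * 1 = 4; passing --gradient_accumulation_steps 8 would raise it
# to 32 without increasing the memory needed for a single forward/backward pass.
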
# ------------
# Evaluate
# ------------
def evaluate(args, model, tokenizer, prefix=""):
set_seed(args)
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_dataset = load_and_cache_examples(args, tokenizer)
eval_sampler = SequentialSampler(eval_dataset)
    model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=model_collate_fn
    )
# multi-gpu evaluate
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
eval_loss = 0.0
nb_eval_steps = 0
model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
source = source.to(args.device)
target = target.to(args.device)
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
encoder_mask = encoder_mask.to(args.device)
decoder_mask = decoder_mask.to(args.device)
lm_labels = lm_labels.to(args.device)
with torch.no_grad():
outputs = model(
source,
target,
encoder_token_type_ids=encoder_token_type_ids,
encoder_attention_mask=encoder_mask,
decoder_attention_mask=decoder_mask,
decoder_lm_labels=lm_labels,
)
lm_loss = outputs[0]
eval_loss += lm_loss.mean().item()
nb_eval_steps += 1
eval_loss = eval_loss / nb_eval_steps
perplexity = torch.exp(torch.tensor(eval_loss))
result = {"perplexity": perplexity}
# Save the evaluation's results
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
return result
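
# The perplexity reported above is the exponential of the average token-level
# cross-entropy: an eval_loss of 2.0, for instance, corresponds to a perplexity
# of exp(2.0) ~ 7.39.
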
def save_model_checkpoints(args, model, tokenizer):
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir, model_type='bert')
tokenizer.save_pretrained(args.output_dir)
torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input training data file (a text file).",
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
)
# Optional parameters
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
    parser.add_argument(
        "--do_evaluate",
        action="store_true",
        help="Run model evaluation on out-of-sample data.",
    )
    parser.add_argument("--do_train", action="store_true", help="Run training.")
    parser.add_argument(
        "--do_overwrite_output_dir",
        action="store_true",
        help="Whether to overwrite the output dir.",
    )
parser.add_argument(
"--model_name_or_path",
default="bert-base-cased",
type=str,
help="The model checkpoint to initialize the encoder and decoder's weights with.",
)
parser.add_argument(
"--model_type",
default="bert",
type=str,
help="The decoder architecture to be fine-tuned.",
)
parser.add_argument(
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
)
parser.add_argument(
"--max_steps",
default=-1,
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
    parser.add_argument(
        "--to_cpu", action="store_true", help="Whether to force training on CPU."
    )
parser.add_argument(
"--num_train_epochs",
default=10,
type=int,
help="Total number of training epochs to perform.",
)
    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )
parser.add_argument("--seed", default=42, type=int)
args = parser.parse_args()
if (
os.path.exists(args.output_dir)
and os.listdir(args.output_dir)
and args.do_train
and not args.do_overwrite_output_dir
):
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite.".format(
args.output_dir
)
)
# Set up training device
if args.to_cpu or not torch.cuda.is_available():
args.device = torch.device("cpu")
args.n_gpu = 0
else:
args.device = torch.device("cuda")
args.n_gpu = torch.cuda.device_count()
# Load pretrained model and tokenizer. The decoder's weights are randomly initialized.
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
config = BertConfig.from_pretrained(args.model_name_or_path)
decoder_model = BertForMaskedLM(config)
model = Model2Model.from_pretrained(
args.model_name_or_path, decoder_model=decoder_model
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
0,
args.device,
args.n_gpu,
False,
False,
)
logger.info("Training/evaluation parameters %s", args)
# Train the model
model.to(args.device)
if args.do_train:
try:
global_step, tr_loss = train(args, model, tokenizer)
except KeyboardInterrupt:
response = input("You interrupted the training. Do you want to save the model checkpoints? [Y/n]")
if response.lower() in ["", "y", "yes"]:
save_model_checkpoints(args, model, tokenizer)
sys.exit(0)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
save_model_checkpoints(args, model, tokenizer)
# Evaluate the model
results = {}
if args.do_evaluate:
checkpoints = [args.output_dir]
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
encoder_checkpoint = os.path.join(checkpoint, "bert_encoder")
decoder_checkpoint = os.path.join(checkpoint, "bert_decoder")
model = PreTrainedEncoderDecoder.from_pretrained(
encoder_checkpoint, decoder_checkpoint
)
            model.to(args.device)
            results[checkpoint] = evaluate(args, model, tokenizer, prefix=checkpoint)
return results
if __name__ == "__main__":
main()
@@ -25,9 +25,8 @@ class CNNDailyMailDataset(Dataset):
     [2] https://github.com/abisee/cnn-dailymail/
     """

-    def __init__(self, tokenizer, prefix="train", data_dir=""):
+    def __init__(self, data_dir="", prefix="train"):
         assert os.path.isdir(data_dir)
-        self.tokenizer = tokenizer

         # We initialize the class by listing all the files that contain
         # stories and summaries. Files are not read in memory given
@@ -104,31 +103,30 @@ def _add_missing_period(line):
 # --------------------------

-def fit_to_block_size(sequence, block_size, pad_token):
+def fit_to_block_size(sequence, block_size, pad_token_id):
     """ Adapt the source and target sequences' lengths to the block size.
-    If the sequence is shorter than the block size we pad it with -1 ids
-    which correspond to padding tokens.
+    If the sequence is shorter we append padding tokens to the right of the sequence.
     """
     if len(sequence) > block_size:
         return sequence[:block_size]
     else:
-        sequence.extend([pad_token] * (block_size - len(sequence)))
+        sequence.extend([pad_token_id] * (block_size - len(sequence)))
         return sequence

-def build_lm_labels(sequence, pad_token):
-    """ Padding token, encoded as 0, are represented by the value -1 so they
+def build_lm_labels(sequence, pad_token_id):
+    """ Padding tokens are replaced by the value -1 so they
     are not taken into account in the loss computation. """
     padded = sequence.clone()
-    padded[padded == pad_token] = -1
+    padded[padded == pad_token_id] = -1
     return padded

-def build_mask(sequence, pad_token):
+def build_mask(sequence, pad_token_id):
     """ Builds the mask. The attention mechanism will only attend to positions
     with value 1. """
     mask = torch.ones_like(sequence)
-    idx_pad_tokens = sequence == pad_token
+    idx_pad_tokens = sequence == pad_token_id
     mask[idx_pad_tokens] = 0
     return mask
from .beam_search import BeamSearch
 # coding=utf-8
-# Copyright (c) 2019 Yang Liu
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
+# MIT License
+# Copyright (c) 2017-Present OpenNMT
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is furnished to do
+# so, subject to the following conditions:

 # The above copyright notice and this permission notice shall be included in all
 # copies or substantial portions of the Software.
@@ -19,69 +21,161 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 """
-A general wrapper around models with LM heads to generate sequences
-using beam search.
+Use Beam Search to generate sequences using encoder-decoder models.
 """
 import torch
 from torch import nn

-class TransformerBeamSearch(nn.Module):
+class BeamSearch(nn.Module):
     def __init__(
         self,
         model,
         tokenizer,
-        batch_size,
         beam_size,
         min_length,
         max_length,
+        batch_size=1,
         alpha=0,
-        block_repeating_trigram=True,
+        block_repeating_trigrams=True,
     ):
+        r"""
+        Inputs:
+            **model**: instance of ``transformers.PreTrainedEncoderDecoder``
+                The pretrained encoder-decoder model that will be used to generate the sequences.
+            **tokenizer**: instance of ``transformers.PreTrainedTokenizer``
+                The pretrained tokenizer associated to the model used in the encoder-decoder. We only
+                support encoder-decoders that use the same tokenizer for encoder and decoder. The tokenizer
+                needs to be initialized or this function will raise an exception.
+            **batch_size**: (`optional`) int
+                Batch size of the inputs. The value is set automatically when calling `forward`.
+            **beam_size**: int
+                Number of beams that are used for each element of the batch.
+            **min_length**: int
+                Minimum number of steps performed by the beam search before terminating.
+            **max_length**: int
+                Maximum number of steps performed by the beam search. Any beam that has not finished
+                will return its current solution with the highest probability. The sequence that is
+                returned has a length of max_length-1 to account for the end token that is subsequently added.
+            **alpha**: float
+                Parameter of the length penalty. Read the documentation of the `_length_penalty` method for more details.
+            **block_repeating_trigrams**: bool
+                Whether to block sequences that have repeating 3-grams.
+        """
-        """
-        Attributes:
-            mask_word_id: token id that corresponds to the mask
-        """
-        super(TransformerBeamSearch, self).__init__()
+        super(BeamSearch, self).__init__()
         self.model = model
         self.tokenizer = tokenizer
-        self.start_token_id = tokenizer.start_token_id
-        self.end_token_id = tokenizer.end_token_id
+        self.bos_token_id = tokenizer.bos_token_id
+        self.eos_token_id = tokenizer.eos_token_id
         self.pad_token_id = tokenizer.pad_token_id
-        self.batch_size = batch_size
         self.beam_size = beam_size
         self.min_length = min_length
         self.max_length = max_length
-        self.block_repeating_trigram = block_repeating_trigram
+        self.block_repeating_trigram = block_repeating_trigrams
         self.apply_length_penalty = False if alpha == 0 else True
         self.alpha = alpha

-        # State of the beam
+        self._init_beam_state(batch_size)
+
+    def __len__(self):
+        try:
+            return self.growing_beams.size(1)
+        except NameError:
+            return 0
+
+    def _init_beam_state(self, batch_size):
+        """ (re-)Initialize the state of the beams. """
         self.hypotheses = [[] for _ in range(batch_size)]
         self.batch_offset = torch.arange(batch_size, dtype=torch.long)
         self.beam_offset = torch.arange(
             0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long
         )
-        self.growing_beam = torch.full(
-            (batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
+        self.growing_beams = torch.full(
+            (batch_size * self.beam_size, 1), self.bos_token_id, dtype=torch.long
         )
         self.topk_log_probabilities = torch.tensor(
             [0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float
         ).repeat(batch_size)
         self.results = {
-            "prediction": [[] for _ in batch_size],
-            "scores": [[] for _ in batch_size],
+            "predictions": [[] for _ in range(batch_size)],
+            "scores": [[] for _ in range(batch_size)],
         }
         self._step = 0
         self.is_done = False

-    def step(self, log_probabilities):
-        """ Grows the beam by one step. """
+    def forward(self, encoder_input_ids, **model_kwargs):
+        """ Generate a sequence using Beam Search. """
+        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
+        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
+        # that apply to the model as a whole.
+        # We let the specific kwargs override the common ones in case of conflict.
+        kwargs_common = {
+            argument: value
+            for argument, value in model_kwargs.items()
+            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
+        }
+        kwargs_decoder = kwargs_common.copy()
+        kwargs_encoder = kwargs_common.copy()
+        kwargs_encoder.update(
+            {
+                argument[len("encoder_") :]: value
+                for argument, value in model_kwargs.items()
+                if argument.startswith("encoder_")
+            }
+        )
+        kwargs_decoder.update(
+            {
+                argument[len("decoder_") :]: value
+                for argument, value in model_kwargs.items()
+                if argument.startswith("decoder_")
+            }
+        )
+
+        # forward pass on the encoder
+        encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder)
+        kwargs_decoder["encoder_hidden_states"] = tile(
+            encoder_outputs, self.beam_size, dim=0
+        )
+
+        # grow the beam by generating sequences in an autoregressive way
+        batch_size = encoder_input_ids.size(0)
+        self._init_beam_state(batch_size)
+        for step in range(self.max_length):
+            # prepare the decoder input
+            decoder_input = fit_to_block_size(
+                self.growing_beams, self.tokenizer.pad_token_id
+            )
+            kwargs_decoder["decoder_lm_labels"] = build_lm_labels(
+                decoder_input, self.tokenizer.pad_token_id
+            )
+            kwargs_decoder["decoder_attention_mask"] = build_mask(
+                decoder_input, self.tokenizer.pad_token_id
+            )
+
+            outputs = self.model.decoder(decoder_input, kwargs_decoder)
+            log_probabilities = torch.nn.functional.log_softmax(outputs[1])
+            surviving_beams_rows = self.grow(log_probabilities)
+            if self.is_done:
+                break
+
+            kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
+                "encoder_hidden_states"
+            ].index_select(0, surviving_beams_rows)
+            kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[
+                "encoder_attention_mask"
+            ].index_select(0, surviving_beams_rows)
+
+        return self.results
+
+    def grow(self, log_probabilities):
+        """ Grow the beams by one step. """
         self._step += 1

-        # The batch size changes as some beams finish so we define _B
+        # The number of beams changes as some beams finish so we define _B
         vocab_size = log_probabilities.size(-1)
         _B = log_probabilities.size(0) // self.beam_size
@@ -89,21 +183,21 @@ class TransformerBeamSearch(nn.Module):
         # next token (conditioned on the words in the beam).
         log_probabilities += self.topk_log_probabilities.view(-1, 1)

-        self.enforce_min_length(log_probabilities)
+        self._enforce_min_length(log_probabilities)
         if self.block_repeating_trigram:
-            self.remove_repeating_trigrams(log_probabilities, _B)
+            self._remove_beams_with_repeating_trigrams(log_probabilities, _B)

         # Find the `beam_size` (previous_beam + token) combinations with
         # the highest score
-        topk_log_probabilities, topk_ids = log_probabilities.topk(
-            log_probabilities.view(_B, self.beam_size * vocab_size),
-            self.beam_size,
-            dim=1,
+        topk_log_probabilities, topk_ids = torch.topk(
+            log_probabilities.view(_B, self.beam_size * vocab_size), self.beam_size, dim=1
         )

         # Apply the length penalty. The +1 accounts for the [EOS] token
         # that will be added if the beam ends.
-        topk_scores = topk_log_probabilities / self.length_penalty()
+        topk_scores = topk_log_probabilities
+        if self.apply_length_penalty:
+            topk_scores /= self._length_penalty()

         # Retrieve the corresponding respective beam and token id
         # topk_token_ids[i] will be added to topk_beam_ids[i]
@@ -112,14 +206,13 @@ class TransformerBeamSearch(nn.Module):
         # Retrieve the row index of the surviving beams in the original
         # view of the log_probabilities tensor
-        surviving_beams_rows = (topk_beam_ids + self.beam_offset[:_B].view(-1, 1)).view(
-            -1
-        )
+        surviving_beams_per_batch = topk_beam_ids + self.beam_offset[:_B].view(-1, 1)
+        surviving_beams_rows = surviving_beams_per_batch.view(-1)

         # Append the last predictions
-        self.growing_beam = torch.cat(
+        self.growing_beams = torch.cat(
             [
-                self.growing_beam.index_select(0, surviving_beams_rows),
+                self.growing_beams.index_select(0, surviving_beams_rows),
                 topk_token_ids.view(-1, 1),
             ],
             1,
@@ -128,21 +221,38 @@ class TransformerBeamSearch(nn.Module):
         # Check if any of the beam searches has ended during this
         # growth step. Also if top beam (most probable) has ended
         # for one element of the batch.
-        is_finished = topk_token_ids.eq(self.end_token_id)
-        self.enforce_max_length()
-        is_top_beam_finished = is_finished[:, 0].eq(1)
+        is_finished = topk_token_ids.eq(self.eos_token_id)
+        self._enforce_max_length(is_finished)
+        if is_finished.any():
+            non_finished = self._cut_finished(is_finished, topk_scores)
+            self.batch_offset = self.batch_offset.index_select(0, non_finished)
+            surviving_beams_per_batch = surviving_beams_per_batch.index_select(
+                0, non_finished
+            )
+            self.topk_log_probabilities = self.topk_log_probabilities.index_select(
+                0, non_finished
+            )
+
+        surviving_beams_rows = surviving_beams_per_batch.view(-1)
+        self.growing_beams = self.growing_beams.index_select(0, surviving_beams_rows)
+
+        return surviving_beams_rows
+
+    def _cut_finished(self, is_finished, topk_scores):
+        """ Save the finished searches and cut the corresponding sequences off
+        the beams. """
+        is_top_beam_finished = is_finished[:, 0].eq(True)

         # Save the finished searches
-        if is_finished.any():
-            predictions = self.growing_beam.view(
-                -1, self.beam_size, self.growing_beam.size(1)
-            )
+        predictions = self.growing_beams.view(
+            -1, self.beam_size, self.growing_beams.size(1)
+        )
         for i in range(is_finished.size(0)):
             if is_top_beam_finished[i]:
                 is_finished[i].fill_(1)
             finished_hyp = is_finished[i].nonzero().view(-1)

-            # Store finished hypotheses for this batch.
+            # Store the finished beams as a (score, prediction) hypothesis.
             b = self.batch_offset[i]
             for j in finished_hyp:
                 self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
@@ -150,95 +260,44 @@ class TransformerBeamSearch(nn.Module):
             # If the batch reached the end, save the best hypotheses
             # in terms of length-penalized score.
             if is_top_beam_finished[i]:
-                best_hyp = sorted(
-                    self.hypotheses[b], key=lambda x: x[0], reverse=True
-                )
-                best_score, best_prediction = best_hyp[0]
+                best_score, best_prediction = max(self.hypotheses[b], key=lambda x: x[0])
                 self.results["scores"][b].append(best_score)
                 self.results["predictions"][b].append(best_prediction)

-        non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
+        non_finished = is_top_beam_finished.eq(False).nonzero().view(-1)
         if len(non_finished) == 0:
             self.is_done = True

-        # Remove finished batches for the next step.
-        topk_log_probabilities = topk_log_probabilities.index_select(
-            0, non_finished
-        )
-        self.batch_offset = self.batch_offset.index_select(0, non_finished)
-        self.growing_beam = predictions.index_select(0, non_finished).view(
-            -1, self.growing_beam.size(-1)
-        )
-        surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
-
-        return surviving_beams_rows
-
-    def forward(self, encoder_input_ids, **kwargs):
-        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
-        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
-        # that apply to the model as whole.
-        # We let the specific kwargs override the common ones in case of conflict.
-        kwargs_encoder = {
-            argument[len("encoder_"):]: value
-            for argument, value in kwargs.items()
-            if argument.startswith("encoder_")
-        }
-        kwargs_decoder = {
-            argument[len("decoder_"):]: value
-            for argument, value in kwargs.items()
-            if argument.startswith("decoder_")
-        }
-        kwargs_common = {
-            argument: value
-            for argument, value in kwargs.items()
-            if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
-        }
-        kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
-        kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
-
-        # forward pass on the encoder
-        encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder)
-        kwargs_decoder["encoder_hidden_states"] = tile(
-            encoder_outputs, self.beam_size, dim=0
-        )
-
-        # grow the beam by generating sequences in an autoregressive way
-        self.growing_beam = torch.full(
-            (self.batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
-        )
-        for step in range(self.max_length):
-            decoder_input = self.growing_beam[:, -1]
-            outputs = self.model.decoder(decoder_input, kwargs_decoder)
-            log_probabilities = torch.nn.functional.log_softmax(outputs[1])
-            surviving_beams_rows = self.step(log_probabilities)
-            if self.is_done:
-                break
-
-            kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
-                "encoder_hidden_states"
-            ].index_select(0, surviving_beams_rows)
-
-        return self.results
-
-    def remove_repeating_trigrams(self, log_probabilities, _B):
-        if(self._step + 1 > 3):
+        return non_finished
+
+    def _remove_beams_with_repeating_trigrams(self, log_probabilities, _B):
+        if self._step + 1 > 3:  # [BOS] does not count
             for i in range(_B * self.beam_size):
-                tokens = [t for t in self.growing_beam[i]]
-                trigrams = [(tokens[i-1], tokens[i], tokens[i+1]) for i in range(1, len(words) - 1)]
+                tokens = self.growing_beams[i]
+                trigrams = [
+                    (tokens[j - 1], tokens[j], tokens[j + 1])
+                    for j in range(1, len(self) - 1)
+                ]
                 last_trigram = tuple(trigrams[-1])
                 if last_trigram in trigrams[:-1]:
                     log_probabilities[i] = -1e20

-    def enforce_min_length(self):
+    def _enforce_min_length(self, log_probabilities):
         if self._step < self.min_length:
-            self.log_probabilities[self.end_token_id] = -1e20
+            log_probabilities[:, self.eos_token_id] = -1e20

-    def enforce_max_length(self):
+    def _enforce_max_length(self, is_finished):
+        # +1 because we will need to add an [EOS] token
         if self._step + 1 == self.max_length:
-            self.is_finished.fill_(1)
+            is_finished.fill_(1)

-    def length_penalty(self):
+    def _length_penalty(self):
+        """ The calculation of the length penalty follows that of [1].
+
+        [1] Wu, Yonghui, et al. "Google's neural machine translation system:
+        Bridging the gap between human and machine translation." arXiv preprint
+        arXiv:1609.08144 (2016).
+        """
         return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha
@@ -269,3 +328,31 @@ def tile(x, count, dim=0):
     if dim != 0:
         x = x.permute(perm).contiguous()
     return x
+
+
+def fit_to_block_size(sequence, block_size, pad_token_id):
+    """ Adapt the source and target sequences' lengths to the block size.
+    If the sequence is shorter we append padding tokens to the right.
+    """
+    if len(sequence) > block_size:
+        return sequence[:block_size]
+    else:
+        sequence.extend([pad_token_id] * (block_size - len(sequence)))
+        return sequence
+
+
+def build_lm_labels(sequence, pad_token_id):
+    """ Padding tokens, encoded as 0, are represented by the value -1 so they
+    are not taken into account in the loss computation. """
+    padded = sequence.clone()
+    padded[padded == pad_token_id] = -1
+    return padded
+
+
+def build_mask(sequence, pad_token_id):
+    """ Builds the mask. The attention mechanism will only attend to positions
+    with value 1. """
+    mask = torch.ones_like(sequence)
+    idx_pad_tokens = sequence == pad_token_id
+    mask[idx_pad_tokens] = 0
+    return mask
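
# Illustration only: a hedged sketch of how the BeamSearch module added above
# might be driven once a fine-tuned Model2Model and a tokenizer that defines
# bos/eos/pad token ids are available. The generation hyper-parameters below are
# assumptions, not values taken from this commit.
#
#   from transformers.generate import BeamSearch
#
#   generator = BeamSearch(
#       model=model,                    # a PreTrainedEncoderDecoder / Model2Model
#       tokenizer=tokenizer,
#       beam_size=5,
#       min_length=10,
#       max_length=100,
#       alpha=0.95,                     # length penalty from Wu et al. (2016)
#       block_repeating_trigrams=True,
#   )
#   results = generator(encoder_input_ids, encoder_attention_mask=encoder_mask)
#   best_prediction = results["predictions"][0][0]  # top beam, first batch element
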
from collections import namedtuple
import unittest
import numpy as np
import torch
from transformers.generate import BeamSearch
from transformers import PreTrainedEncoderDecoder
StubTokenizer = namedtuple("Tokenizer", ["bos_token_id", "eos_token_id", "pad_token_id"])
StubTransformer = namedtuple("Transformer", ["encoder", "decoder"])
class BeamSearchTest(unittest.TestCase):
def test_beam_search_encoder_decoder_integration(self):
""" We make sure that no internal change in the PreTrainedEncoderDecoder
class will break the integration with the beam search.
"""
model = PreTrainedEncoderDecoder("encoder", "decoder")
tokenizer = StubTokenizer(0, 1, 2)
try:
_ = BeamSearch(
model=model,
tokenizer=tokenizer,
batch_size=1,
beam_size=1,
min_length=1,
max_length=1,
alpha=0,
block_repeating_trigrams=False,
)
        except Exception:
            self.fail("Instantiating BeamSearch with a PreTrainedEncoderDecoder failed.")
def test_beam_search_min_length(self):
""" We keep predicting the end_token for the first beam and check that
it is not marked as finished until the beam has reached the minimum
length. """
eos_idx = 3
vocab_size = 10
batch_size = 3
beam_size = 2
min_length = 5
beam = BeamSearch(
model=StubTransformer("encoder", "decoder"),
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=eos_idx, pad_token_id=2),
batch_size=batch_size,
beam_size=beam_size,
min_length=5,
max_length=10,
alpha=0,
block_repeating_trigrams=False,
)
# To test that the minimum length is correctly enforced we constantly
# assign the highest probability to the [EOS] token (and assign lower
# probabilities to some other tokens).
        # Since BeamSearch will reset the [EOS] log probability to -1e20 as long
        # as min_length has not been reached, we need to reset the value between
        # steps.
non_eos_idxs = [4, 5, 1, 8, 9]
score_distribution = torch.log_softmax(
torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0
)
log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
log_probabilities[0, eos_idx] = score_distribution[0]
for idx, score in zip(non_eos_idxs, score_distribution[1:]):
log_probabilities[0, idx] = score
for step in range(1, min_length + 2):
log_probabilities[0, eos_idx] = score_distribution[0]
            # Beams #3 and #4 terminate at the first step since the probability
            # of the [EOS] token is -1e20 > -inf, so there are only two beams left.
surviving_beams_rows = beam.grow(log_probabilities)
if step < min_length:
np.testing.assert_array_equal(
beam.growing_beams.numpy(),
np.repeat(np.array([[0] + [4] * step]), 2, axis=0),
)
elif step == min_length:
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
self.assertTrue(beam.is_done)
break
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
def test_beam_search_max_length(self):
""" We keep predicting the same non-EOS token until we reach the
maximum permitted length """
batch_size = 3
beam_size = 2
max_length = 5
vocab_size = 10
beam = BeamSearch(
model=StubTransformer("encoder", "decoder"),
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
batch_size=batch_size,
beam_size=beam_size,
min_length=2,
max_length=max_length,
alpha=0,
block_repeating_trigrams=False,
)
log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
# To test that beam search enforces the max length constraint we
# keep giving the highest probability to a token that is not the
# [EOS] token.
# The beam search will stop at max_length-1, assuming that one would
# add the [EOS] token at the end of the returned sequence.
token_idxs = [3, 4, 5]
score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
for idx, score in zip(token_idxs, score_distribution):
log_probabilities[:, idx] = score
for step in range(1, max_length + 2):
surviving_beams_rows = beam.grow(log_probabilities)
if step + 1 < max_length:
self.assertFalse(beam.is_done)
elif step + 1 == max_length: # Now [EOS] is the most probable token
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
self.assertTrue(beam.is_done)
break
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
def test_beam_search_block_repeating_trigrams(self):
""" We make sure that the beams that contain repeating trigrams are removed. """
batch_size = 3
beam_size = 2
max_length = 10
vocab_size = 10
beam = BeamSearch(
model=StubTransformer("encoder", "decoder"),
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
batch_size=batch_size,
beam_size=beam_size,
min_length=2,
max_length=max_length,
alpha=0,
block_repeating_trigrams=True,
)
log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
# To test that BeamSearch enforces the 3-gram constraint we give the
# highest probably to the same tokens in a cyclic fashion and make sure
# they disappear once the cycle has completed.
token_idxs = [3, 4, 5]
score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
for idx, score in zip(token_idxs, score_distribution):
log_probabilities[:, idx] = score
for step in range(1, max_length + 2):
# Rotate the probabilities at each step
for idx in token_idxs:
score = score_distribution[(idx + step) % 3]
log_probabilities[::beam_size, idx] = score
surviving_beams_rows = beam.grow(log_probabilities)
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
if step < 7:
self.assertFalse(
np.array_equal(
log_probabilities.numpy()[0, :],
np.array([-1e20] * vocab_size, dtype="float32"),
)
)
if step == 7:
np.testing.assert_array_equal(
log_probabilities.numpy()[0, :],
np.array([-1e20] * vocab_size, dtype="float32"),
)
def test_beam_search_example_for_one_step(self):
""" We test that the predictions for one step of growth are correct. """
batch_size = 2
beam_size = 2
max_length = 10
vocab_size = 5
beam = BeamSearch(
model=StubTransformer("encoder", "decoder"),
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
batch_size=batch_size,
beam_size=beam_size,
min_length=2,
max_length=max_length,
alpha=0,
block_repeating_trigrams=False,
)
log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
log_probabilities[0, 3:] = torch.log_softmax(torch.tensor([2.0, 1.0]), dim=0)
log_probabilities[2, 3:] = torch.log_softmax(torch.tensor([1.0, 2.0]), dim=0)
# First pass
surviving_beams_rows = beam.grow(log_probabilities)
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
np.testing.assert_array_equal(
beam.growing_beams.numpy(), np.array([[0, 3], [0, 4], [0, 4], [0, 3]])
)
self.assertFalse(beam.is_done)
# Second pass
surviving_beams_rows = beam.grow(log_probabilities)
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
np.testing.assert_array_equal(
beam.growing_beams.numpy(),
np.array([[0, 3, 3], [0, 3, 4], [0, 4, 4], [0, 4, 3]]),
)
self.assertFalse(beam.is_done)
if __name__ == "__name__":
unittest.main()