Commit 4735c2af authored by Rémi Louf, committed by Julien Chaumond

tweaks to the BeamSearch API

parent ba089c78
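In short, `BeamSearch` is no longer an `nn.Module` configured with a tokenizer and an explicit device: it becomes a plain object that takes the special token ids directly, infers the device from the model, and is invoked through `__call__` rather than `forward`. A minimal sketch of the new usage (the token ids, `model`, and `encoder_input_ids` are illustrative; the import path follows the tests below):

from transformers.generate import BeamSearch

beam = BeamSearch(
    model=model,          # a PreTrainedEncoderDecoder, assumed already loaded
    bos_token_id=0,       # illustrative token ids
    eos_token_id=1,
    pad_token_id=2,
    batch_size=1,
    beam_size=4,
    min_length=5,
    max_length=100,
)
summary = beam(encoder_input_ids)  # __call__ replaces the former forward()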
@@ -32,7 +32,7 @@ import logging
 logger = logging.getLogger(__name__)
 
-class BeamSearch(nn.Module):
+class BeamSearch(object):
     def __init__(
         self,
         model,
@@ -45,12 +45,17 @@ class BeamSearch(nn.Module):
         max_length,
         alpha=0,
         block_repeating_trigrams=True,
-        device=torch.device("cpu"),
     ):
         r"""
         Inputs:
             **model**: instance of ``transformers.PreTrainedEncoderDecoder``
                 The pretrained encoder-decoder model that will be used to generate the sequences.
+            **bos_token_id**: int
+                Id that is used by the tokenizer to represent the beginning of a sentence.
+            **pad_token_id**: int
+                Id that is used by the tokenizer for padding.
+            **eos_token_id**: int
+                Id that is used by the tokenizer to represent the end of a sentence.
             **batch_size**: (`optional`) int
                 Batch size of the inputs. The value is set automatically when calling `forward`.
             **beam_size**: int
@@ -68,7 +73,7 @@
         """
         super(BeamSearch, self).__init__()
         self.model = model
-        self.device = device
+        self.device = next(model.parameters()).device  # only works if all parameters of the model are stored on a single GPU
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
@@ -86,10 +91,7 @@
         self._init_beam_state(batch_size)
 
     def __len__(self):
-        try:
-            return self.growing_beams.size(1)
-        except NameError:
-            return 0
+        return self.growing_beams.size(1)
 
     def _init_beam_state(self, batch_size):
         """ (re-)Initialize the state of the beams. """
@@ -120,7 +122,7 @@
         self._step = 0
         self.is_done = False
 
-    def forward(self, encoder_input_ids, **model_kwargs):
+    def __call__(self, encoder_input_ids, **model_kwargs):
         """ Generate a sequence using Beam Search. """
         # keyword arguments come in 3 flavors: encoder-specific (prefixed by
         # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
@@ -158,28 +160,17 @@
             kwargs_encoder["attention_mask"], self.beam_size, dim=0
         )
 
-        # grow the beam by generating sequences in an autoregressive way
+        # grow the beam iteratively
         batch_size, block_size = encoder_input_ids.size()
         self._init_beam_state(batch_size)
         for step in range(self.max_length):
-            # Add padding tokens
-            decoder_input = torch.full(
-                (self.growing_beams.size(0), block_size),
-                self.pad_token_id,
-                dtype=torch.long,
-                device=self.growing_beams.device,
-            )
-            decoder_input[:, : self.growing_beams.size(1)] = self.growing_beams
-
-            # compute decoder_attention_mask
-            decoder_mask = torch.ones_like(decoder_input)
-            idx_pad_tokens = decoder_input == self.pad_token_id
-            decoder_mask[idx_pad_tokens] = 0
-            kwargs_decoder["attention_mask"] = decoder_mask
+            decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
+            kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id)
 
             outputs = self.model.decoder(decoder_input, **kwargs_decoder)
-            last_token_scores = outputs[0][:, -1, :].squeeze(1)
-            log_probabilities = torch.nn.functional.log_softmax(last_token_scores, dim=0)
+            next_token_scores = outputs[0][:, -1, :].squeeze(1)
+            log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=-1)
 
             surviving_beams_rows = self.grow(log_probabilities)
             if self.is_done:
                 break
@@ -356,20 +347,14 @@ def fit_to_block_size(sequence, block_size, pad_token_id):
     """ Adapt the source and target sequences' lengths to the block size.
     If the sequence is shorter we append padding tokens to the right.
     """
-    if len(sequence) > block_size:
-        return sequence[:block_size]
-    else:
-        return torch.cat(
-            (sequence, torch.tensor([pad_token_id] * (block_size - len(sequence)))), dim=0
-        )
-
-
-def build_lm_labels(sequence, pad_token_id):
-    """ Padding token, encoded as 0, are represented by the value -1 so they
-    are not taken into account in the loss computation. """
-    padded = sequence.clone()
-    padded[padded == pad_token_id] = -1
-    return padded
+    padded_sequence = torch.full(
+        (sequence.size(0), block_size),
+        pad_token_id,
+        dtype=torch.long,
+        device=sequence.device,
+    )
+    padded_sequence[:, : sequence.size(1)] = sequence
+    return padded_sequence
 
 
 def build_mask(sequence, pad_token_id):
......
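For intuition, here is what the rewritten helpers produce on a toy batch of growing beams (a sketch; `build_mask`'s body is elided above and is assumed to zero out padding positions, consistent with the inline masking code this commit removes):

import torch

beams = torch.tensor([[0, 4, 4], [0, 7, 1]])  # (batch * beam_size, current_length)
padded = fit_to_block_size(beams, block_size=5, pad_token_id=2)
# tensor([[0, 4, 4, 2, 2],
#         [0, 7, 1, 2, 2]])
mask = build_mask(padded, pad_token_id=2)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0]])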
 from collections import namedtuple
 import unittest
 
 import pytest
 import numpy as np
 import torch
+from torch import nn
 
 from transformers.generate import BeamSearch
-from transformers import PreTrainedEncoderDecoder
-StubTokenizer = namedtuple("Tokenizer", ["bos_token_id", "eos_token_id", "pad_token_id"])
-StubTransformer = namedtuple("Transformer", ["encoder", "decoder"])
+class StubTransformer(nn.Module):
+    def __init__(self):
+        super(StubTransformer, self).__init__()
+        self.encoder = None
+        self.decoder = None
+        # expose one dummy parameter so `next(model.parameters()).device` works
+        self._parameters = {"dummy": torch.tensor([1])}
+
+    def forward(self):
+        pass
 
 
 class BeamSearchtest(unittest.TestCase):
@@ -18,12 +25,13 @@ class BeamSearchtest(unittest.TestCase):
         class will break the integration with the beam search.
         """
-        model = PreTrainedEncoderDecoder("encoder", "decoder")
-        tokenizer = StubTokenizer(0, 1, 2)
+        model = StubTransformer()
         try:
             _ = BeamSearch(
                 model=model,
-                tokenizer=tokenizer,
+                bos_token_id=0,
+                eos_token_id=1,
+                pad_token_id=2,
                 batch_size=1,
                 beam_size=1,
                 min_length=1,
@@ -46,8 +54,10 @@ class BeamSearchtest(unittest.TestCase):
         min_length = 5
         beam = BeamSearch(
-            model=StubTransformer("encoder", "decoder"),
-            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=eos_idx, pad_token_id=2),
+            model=StubTransformer(),
+            bos_token_id=0,
+            eos_token_id=eos_idx,
+            pad_token_id=2,
             batch_size=batch_size,
             beam_size=beam_size,
             min_length=5,
@@ -71,17 +81,17 @@ class BeamSearchtest(unittest.TestCase):
         log_probabilities[0, eos_idx] = score_distribution[0]
         for idx, score in zip(non_eos_idxs, score_distribution[1:]):
             log_probabilities[0, idx] = score
 
         for step in range(1, min_length + 2):
             log_probabilities[0, eos_idx] = score_distribution[0]
             # Beams #3 and #4 terminate at the first step since the probability
             # of the [EOS] token is -1e20 > -inf, so there are only two beams left.
             # The top beam (most likely) always ends with 4 until we reach min_length.
             surviving_beams_rows = beam.grow(log_probabilities)
             if step < min_length:
                 np.testing.assert_array_equal(
-                    beam.growing_beams.numpy(),
-                    np.repeat(np.array([[0] + [4] * step]), 2, axis=0),
+                    beam.growing_beams.numpy()[0, :], np.array([0] + [4] * step)
                 )
             elif step == min_length:
                 np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
@@ -99,8 +109,10 @@ class BeamSearchtest(unittest.TestCase):
         vocab_size = 10
         beam = BeamSearch(
-            model=StubTransformer("encoder", "decoder"),
-            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
+            model=StubTransformer(),
+            bos_token_id=0,
+            eos_token_id=1,
+            pad_token_id=2,
             batch_size=batch_size,
             beam_size=beam_size,
             min_length=2,
@@ -140,8 +152,10 @@ class BeamSearchtest(unittest.TestCase):
         vocab_size = 10
         beam = BeamSearch(
-            model=StubTransformer("encoder", "decoder"),
-            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
+            model=StubTransformer(),
+            bos_token_id=0,
+            eos_token_id=1,
+            pad_token_id=2,
             batch_size=batch_size,
             beam_size=beam_size,
             min_length=2,
@@ -167,7 +181,6 @@ class BeamSearchtest(unittest.TestCase):
                 log_probabilities[::beam_size, idx] = score
 
             surviving_beams_rows = beam.grow(log_probabilities)
-            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
 
             if step < 7:
                 self.assertFalse(
@@ -182,6 +195,8 @@ class BeamSearchtest(unittest.TestCase):
                     np.array([-1e20] * vocab_size, dtype="float32"),
                 )
 
+            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
+
     def test_beam_search_example_for_one_step(self):
         """ We test that the predictions for one step of growth are correct. """
         batch_size = 2
@@ -190,8 +205,10 @@ class BeamSearchtest(unittest.TestCase):
         vocab_size = 5
         beam = BeamSearch(
-            model=StubTransformer("encoder", "decoder"),
-            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
+            model=StubTransformer(),
+            bos_token_id=0,
+            eos_token_id=1,
+            pad_token_id=2,
             batch_size=batch_size,
             beam_size=beam_size,
             min_length=2,
......