Commit 4735c2af authored by Rémi Louf, committed by Julien Chaumond

tweaks to the BeamSearch API

parent ba089c78
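This commit reworks the public surface of the experimental `BeamSearch` class: it no longer subclasses `nn.Module`, it takes `bos_token_id`, `pad_token_id` and `eos_token_id` directly instead of a tokenizer, it reads its device from the model's parameters rather than a `device` argument, and generation goes through `__call__` instead of `forward`. The sketch below shows how the revised constructor might be used on this branch; the checkpoint names, token-id choices and the nature of the returned value are assumptions, not part of the diff.

```python
# Minimal sketch of the revised BeamSearch API (assumptions: a BERT-to-BERT
# PreTrainedEncoderDecoder and BERT's [CLS]/[SEP]/[PAD] ids as bos/eos/pad;
# neither the checkpoints nor the id choices are part of this commit).
import torch
from transformers import BertTokenizer, PreTrainedEncoderDecoder
from transformers.generate import BeamSearch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = PreTrainedEncoderDecoder.from_pretrained("bert-base-uncased", "bert-base-uncased")

beam_search = BeamSearch(
    model=model,                             # device is now read from the model's parameters
    bos_token_id=tokenizer.cls_token_id,     # token ids are passed explicitly
    eos_token_id=tokenizer.sep_token_id,     # instead of a tokenizer object
    pad_token_id=tokenizer.pad_token_id,
    batch_size=1,
    beam_size=4,
    min_length=5,
    max_length=50,
)

encoder_input_ids = torch.tensor([tokenizer.encode("A short document to summarize.")])
results = beam_search(encoder_input_ids)     # the object is called directly, forward() is gone
```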
@@ -32,7 +32,7 @@ import logging
 logger = logging.getLogger(__name__)

-class BeamSearch(nn.Module):
+class BeamSearch(object):
     def __init__(
         self,
         model,
@@ -45,12 +45,17 @@ class BeamSearch(nn.Module):
         max_length,
         alpha=0,
         block_repeating_trigrams=True,
-        device=torch.device("cpu"),
     ):
         r"""
         Inputs:
             **model**: instance of ``transformers.PreTrainedEncoderDecoder``
                 The pretrained encoder-decoder model that will be used to generate the sequences.
+            **bos_token_id**: int
+                Id that is used by the tokenizer to represent the beginning of a sentence.
+            **pad_token_id**: int
+                Id that is used by the tokenizer for padding.
+            **eos_token_id**: int
+                Id that is used by the tokenizer to represent the end of a sentence.
             **batch_size**: (`optional`) int
                 Batch size of the inputs. The value is set automatically when calling `forward`.
             **beam_size**: int
@@ -68,7 +73,7 @@ class BeamSearch(nn.Module):
         """
         super(BeamSearch, self).__init__()
         self.model = model
-        self.device = device
+        self.device = next(model.parameters()).device  # only works if all parameters of the model are stored on a single GPU
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
@@ -86,10 +91,7 @@ class BeamSearch(nn.Module):
         self._init_beam_state(batch_size)

     def __len__(self):
-        try:
-            return self.growing_beams.size(1)
-        except NameError:
-            return 0
+        return self.growing_beams.size(1)

     def _init_beam_state(self, batch_size):
         """ (re-)Initialize the state of the beams. """
@@ -120,7 +122,7 @@ class BeamSearch(nn.Module):
         self._step = 0
         self.is_done = False

-    def forward(self, encoder_input_ids, **model_kwargs):
+    def __call__(self, encoder_input_ids, **model_kwargs):
         """ Generate a sequence using Beam Search. """
         # keyword arguments come in 3 flavors: encoder-specific (prefixed by
         # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
@@ -158,28 +160,17 @@ class BeamSearch(nn.Module):
                 kwargs_encoder["attention_mask"], self.beam_size, dim=0
             )

-        # grow the beam by generating sequences in an autoregressive way
+        # grow the beam iteratively
         batch_size, block_size = encoder_input_ids.size()
         self._init_beam_state(batch_size)
         for step in range(self.max_length):
-            # Add padding tokens
-            decoder_input = torch.full(
-                (self.growing_beams.size(0), block_size),
-                self.pad_token_id,
-                dtype=torch.long,
-                device=self.growing_beams.device,
-            )
-            decoder_input[:, : self.growing_beams.size(1)] = self.growing_beams
-            # compute decoder_attention_mask
-            decoder_mask = torch.ones_like(decoder_input)
-            idx_pad_tokens = decoder_input == self.pad_token_id
-            decoder_mask[idx_pad_tokens] = 0
-            kwargs_decoder["attention_mask"] = decoder_mask
+            decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
+            kwargs_decoder["attention_mask"] = build_mask(decoder_input)
             outputs = self.model.decoder(decoder_input, **kwargs_decoder)
-            last_token_scores = outputs[0][:, -1, :].squeeze(1)
-            log_probabilities = torch.nn.functional.log_softmax(last_token_scores, dim=0)
+            next_token_scores = outputs[0][:, -1, :].squeeze(1)
+            log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=0)

             surviving_beams_rows = self.grow(log_probabilities)
             if self.is_done:
                 break
@@ -356,20 +347,14 @@ def fit_to_block_size(sequence, block_size, pad_token_id):
     """ Adapt the source and target sequences' lengths to the block size.
     If the sequence is shorter we append padding tokens to the right.
     """
-    if len(sequence) > block_size:
-        return sequence[:block_size]
-    else:
-        return torch.cat(
-            (sequence, torch.tensor([pad_token_id] * (block_size - len(sequence)))), dim=0
-        )
+    padded_sequence = torch.full(
+        (sequence.size(0), block_size),
+        pad_token_id,
+        dtype=torch.long,
+        device=sequence.device,
+    )
+    padded_sequence[:, : sequence.size(1)] = sequence
+    return padded_sequence

-def build_lm_labels(sequence, pad_token_id):
-    """ Padding token, encoded as 0, are represented by the value -1 so they
-    are not taken into account in the loss computation. """
-    padded = sequence.clone()
-    padded[padded == pad_token_id] = -1
-    return padded

 def build_mask(sequence, pad_token_id):
...
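For context, the decoding loop above now delegates padding and masking of the growing beams to `fit_to_block_size` and `build_mask`, and `fit_to_block_size` operates on a whole batch (a 2D tensor) instead of a single sequence. The toy example below replicates that padding logic inline to show the intended shapes; the `pad_token_id` value and the 1/0 mask convention of `build_mask` (whose body is not in this hunk) are assumptions.

```python
# Toy illustration of the refactored helpers (pad_token_id assumed to be 2; the
# padding logic is replicated inline since the helpers' module is not imported here).
import torch

pad_token_id = 2
growing_beams = torch.tensor([[0, 7, 4],
                              [0, 7, 9]])    # (num_beams, current_length)
block_size = 6

# what fit_to_block_size now does: right-pad the whole batch to block_size
decoder_input = torch.full(
    (growing_beams.size(0), block_size), pad_token_id, dtype=torch.long
)
decoder_input[:, : growing_beams.size(1)] = growing_beams
# tensor([[0, 7, 4, 2, 2, 2],
#         [0, 7, 9, 2, 2, 2]])

# what build_mask is expected to return: 1 on real tokens, 0 on padding
attention_mask = (decoder_input != pad_token_id).long()
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0]])
```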
 from collections import namedtuple
 import unittest
-import pytest

 import numpy as np
 import torch
+from torch import nn

 from transformers.generate import BeamSearch
 from transformers import PreTrainedEncoderDecoder

-StubTokenizer = namedtuple("Tokenizer", ["bos_token_id", "eos_token_id", "pad_token_id"])
-StubTransformer = namedtuple("Transformer", ["encoder", "decoder"])
+class StubTransformer(nn.Module):
+    def __init__(self):
+        self.encoder = None
+        self.decoder = None
+        self._parameters = {"dumy": torch.tensor([1])}
+
+    def forward(self):
+        pass

 class BeamSearchtest(unittest.TestCase):
@@ -18,12 +25,13 @@ class BeamSearchtest(unittest.TestCase):
        class will break the integration with the beam search.
        """
-        model = PreTrainedEncoderDecoder("encoder", "decoder")
-        tokenizer = StubTokenizer(0, 1, 2)
+        model = StubTransformer()
         try:
             _ = BeamSearch(
                 model=model,
-                tokenizer=tokenizer,
+                bos_token_id=0,
+                eos_token_id=1,
+                pad_token_id=2,
                 batch_size=1,
                 beam_size=1,
                 min_length=1,
@@ -46,8 +54,10 @@ class BeamSearchtest(unittest.TestCase):
         min_length = 5
         beam = BeamSearch(
-            model=StubTransformer("encoder", "decoder"),
-            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=eos_idx, pad_token_id=2),
+            model=StubTransformer(),
+            bos_token_id=0,
+            eos_token_id=eos_idx,
+            pad_token_id=2,
             batch_size=batch_size,
             beam_size=beam_size,
             min_length=5,
@@ -71,17 +81,17 @@ class BeamSearchtest(unittest.TestCase):
         log_probabilities[0, eos_idx] = score_distribution[0]
         for idx, score in zip(non_eos_idxs, score_distribution[1:]):
             log_probabilities[0, idx] = score
-        pytest.set_trace()

         for step in range(1, min_length + 2):
             log_probabilities[0, eos_idx] = score_distribution[0]
             # Beam #3 and #4 teminate at the first step since the probability
             # of the [EOS] token is -1e20 > -\infty so there are only two beams left.
+            # The top beam (most likely) always ends with 4 until we reach min_length.
             surviving_beams_rows = beam.grow(log_probabilities)
             if step < min_length:
                 np.testing.assert_array_equal(
-                    beam.growing_beams.numpy(),
-                    np.repeat(np.array([[0] + [4] * step]), 2, axis=0),
+                    beam.growing_beams.numpy()[0, :], np.array([0] + [4] * step)
                 )
             elif step == min_length:
                 np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
@@ -99,8 +109,10 @@ class BeamSearchtest(unittest.TestCase):
         vocab_size = 10
         beam = BeamSearch(
-            model=StubTransformer("encoder", "decoder"),
-            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
+            model=StubTransformer(),
+            bos_token_id=0,
+            eos_token_id=1,
+            pad_token_id=2,
             batch_size=batch_size,
             beam_size=beam_size,
             min_length=2,
@@ -140,8 +152,10 @@ class BeamSearchtest(unittest.TestCase):
         vocab_size = 10
         beam = BeamSearch(
-            model=StubTransformer("encoder", "decoder"),
-            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
+            model=StubTransformer(),
+            bos_token_id=0,
+            eos_token_id=1,
+            pad_token_id=2,
             batch_size=batch_size,
             beam_size=beam_size,
             min_length=2,
@@ -167,7 +181,6 @@ class BeamSearchtest(unittest.TestCase):
             log_probabilities[::beam_size, idx] = score

             surviving_beams_rows = beam.grow(log_probabilities)
-            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

             if step < 7:
                 self.assertFalse(
@@ -182,6 +195,8 @@ class BeamSearchtest(unittest.TestCase):
                     np.array([-1e20] * vocab_size, dtype="float32"),
                 )

+            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
+
     def test_beam_search_example_for_one_step(self):
         """ We test that the predictions for one step of growth are correct. """
         batch_size = 2
@@ -190,8 +205,10 @@ class BeamSearchtest(unittest.TestCase):
         vocab_size = 5
         beam = BeamSearch(
-            model=StubTransformer("encoder", "decoder"),
-            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
+            model=StubTransformer(),
+            bos_token_id=0,
+            eos_token_id=1,
+            pad_token_id=2,
             batch_size=batch_size,
             beam_size=beam_size,
             min_length=2,
...
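The test changes mirror the API change: the namedtuple stubs become a `StubTransformer(nn.Module)` whose hand-filled `_parameters` dict is there so that `next(model.parameters())` in the new constructor has something to yield, and the token ids are passed to `BeamSearch` directly. A more conventional way to build such a stub is sketched below; the class is hypothetical and not part of the commit.

```python
# Hypothetical alternative to the diff's StubTransformer: register a real dummy
# parameter so that next(model.parameters()).device works through the normal
# nn.Module machinery (class name and parameter name are illustrative only).
import torch
from torch import nn


class StubTransformerWithParam(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = None
        self.decoder = None
        self.dummy = nn.Parameter(torch.zeros(1))  # registered dummy parameter

    def forward(self):
        pass


stub = StubTransformerWithParam()
print(next(stub.parameters()).device)  # -> cpu
```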