Commit ba089c78 authored by Rémi Louf, committed by Julien Chaumond

share pretrained embeddings

parent 9660ba1c
@@ -136,18 +136,11 @@ def encode_for_summarization(story_lines, summary_lines, tokenizer):
as specified in [1] by using `[SEP] [CLS]` tokens to separate
sentences.
"""
story_lines_token_ids = [
tokenizer.build_inputs_with_special_tokens(tokenizer.encode(line))
for line in story_lines
]
summary_lines_token_ids = [
tokenizer.build_inputs_with_special_tokens(tokenizer.encode(line))
for line in summary_lines
]
story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
story_token_ids = [
token for sentence in story_lines_token_ids for token in sentence
]
summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
summary_token_ids = [
token for sentence in summary_lines_token_ids for token in sentence
]
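The new encoding path encodes each line separately and then concatenates the per-line token ids into a single sequence. A standalone sketch of that flattening, with a toy `toy_encode` standing in for `tokenizer.encode` (not part of this diff):

    # Standalone sketch of the flattening done above; `toy_encode` is only a
    # stand-in for `tokenizer.encode`.
    def toy_encode(line):
        # map each whitespace-separated word to a fake integer id
        return [len(word) for word in line.split()]

    story_lines = ["the cat sat on the mat", "it was warm"]
    story_lines_token_ids = [toy_encode(line) for line in story_lines]
    # flatten the per-sentence id lists into one token id sequence
    story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
    assert story_token_ids == [3, 3, 3, 2, 3, 3, 2, 3, 4]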
......
@@ -10,3 +10,5 @@ regex
sentencepiece
# For XLM
sacremoses
# For ROUGE
pyrouge
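pyrouge wraps the original ROUGE-1.5.5 perl script, which has to be installed separately. A hedged usage sketch; the directories and filename patterns below are placeholders, not paths from this repository:

    from pyrouge import Rouge155

    r = Rouge155()
    r.system_dir = "/tmp/candidate_summaries"        # placeholder path
    r.model_dir = "/tmp/reference_summaries"         # placeholder path
    r.system_filename_pattern = r"(\d+)_candidate.txt"
    r.model_filename_pattern = "#ID#_reference.txt"
    output = r.convert_and_evaluate()                # runs the perl ROUGE-1.5.5 script
    print(r.output_to_dict(output)["rouge_1_f_score"])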
@@ -26,27 +26,31 @@ Use Beam Search to generate sequences using encoder-decoder models.
import torch
from torch import nn
import logging
logger = logging.getLogger(__name__)
class BeamSearch(nn.Module):
def __init__(
self,
model,
tokenizer,
bos_token_id,
pad_token_id,
eos_token_id,
batch_size,
beam_size,
min_length,
max_length,
batch_size=1,
alpha=0,
block_repeating_trigrams=True,
device=torch.device("cpu"),
):
r"""
Inputs:
**model**: instance of ``transformers.PreTrainedEncoderDecoder``
The pretrained encoder-decoder model that will be used to generate the sequences.
**tokenizer**: instance of ``transformers.PreTrainedTokenizer``
The pretrained tokenizer associated with the model used in the encoder-decoder. We only
support encoder-decoder models that use the same tokenizer for the encoder and the decoder.
The tokenizer needs to be initialized or this function will raise an exception.
**batch_size**: (`optional`) int
Batch size of the inputs. The value is set automatically when calling `forward`.
**beam_size**: int
@@ -64,11 +68,11 @@ class BeamSearch(nn.Module):
"""
super(BeamSearch, self).__init__()
self.model = model
self.tokenizer = tokenizer
self.device = device
self.bos_token_id = tokenizer.bos_token_id
self.eos_token_id = tokenizer.eos_token_id
self.pad_token_id = tokenizer.pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.batch_size = batch_size
self.beam_size = beam_size
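With this change the special-token ids are passed to the constructor instead of being read from the tokenizer inside the class. A hedged construction sketch, assuming the parameter names shown in the new signature and that `model` and `tokenizer` are an already loaded encoder-decoder and its tokenizer; since this diff does not show the full positional order, everything after the first arguments is passed by keyword, and the numeric values are only examples:

    import torch

    # Hedged sketch only; `model` and `tokenizer` are assumed to be preloaded,
    # and the keyword names follow the signature shown in this diff.
    beam_search = BeamSearch(
        model,
        tokenizer,
        bos_token_id=tokenizer.bos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        beam_size=5,
        min_length=10,
        max_length=200,
        batch_size=1,
        alpha=0.95,
        block_repeating_trigrams=True,
        device=torch.device("cpu"),
    )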
@@ -90,15 +94,24 @@ class BeamSearch(nn.Module):
def _init_beam_state(self, batch_size):
""" (re-)Initialize the state of the beams. """
self.hypotheses = [[] for _ in range(batch_size)]
self.batch_offset = torch.arange(batch_size, dtype=torch.long)
self.batch_offset = torch.arange(batch_size, dtype=torch.long, device=self.device)
self.beam_offset = torch.arange(
0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long
0,
batch_size * self.beam_size,
step=self.beam_size,
dtype=torch.long,
device=self.device,
)
self.growing_beams = torch.full(
(batch_size * self.beam_size, 1), self.bos_token_id, dtype=torch.long
(batch_size * self.beam_size, 1),
self.bos_token_id,
dtype=torch.long,
device=self.device,
)
self.topk_log_probabilities = torch.tensor(
[0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float
[0.0] + [float("-inf")] * (self.beam_size - 1),
dtype=torch.float,
device=self.device,
).repeat(batch_size)
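For reference, a standalone sketch of the same beam-state initialization with toy sizes, showing the resulting shapes and why only the first hypothesis of each example starts at log-probability 0 (the others start at -inf so the first step does not pick beam_size copies of the same continuation):

    import torch

    batch_size, beam_size, bos_token_id = 2, 3, 101  # toy values
    device = torch.device("cpu")

    # index of the first hypothesis of each example in the flattened (batch * beam) axis
    beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device)
    # every hypothesis starts with the BOS token
    growing_beams = torch.full((batch_size * beam_size, 1), bos_token_id, dtype=torch.long, device=device)
    # only the first hypothesis of each example starts with log-probability 0
    topk_log_probabilities = torch.tensor(
        [0.0] + [float("-inf")] * (beam_size - 1), dtype=torch.float, device=device
    ).repeat(batch_size)

    print(beam_offset)             # tensor([0, 3])
    print(growing_beams.shape)     # torch.Size([6, 1])
    print(topk_log_probabilities)  # tensor([0., -inf, -inf, 0., -inf, -inf])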
self.results = {
"predictions": [[] for _ in range(batch_size)],
@@ -136,28 +149,37 @@ class BeamSearch(nn.Module):
)
# forward pass on the encoder
encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder)
encoder_outputs = self.model.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
kwargs_decoder["encoder_hidden_states"] = tile(
encoder_outputs, self.beam_size, dim=0
encoder_hidden_states, self.beam_size, dim=0
)
kwargs_decoder["encoder_attention_mask"] = tile(
kwargs_encoder["attention_mask"], self.beam_size, dim=0
)
# grow the beam by generating sequences in an autoregressive way
batch_size = encoder_input_ids.size(0)
batch_size, block_size = encoder_input_ids.size()
self._init_beam_state(batch_size)
for step in range(self.max_length):
# prepare the decoder input
decoder_input = fit_to_block_size(
self.growing_beams, self.tokenizer.pad_token_id
)
kwargs_decoder["decoder_lm_labels"] = build_lm_labels(
decoder_input, self.tokenizer.pad_token_id
)
kwargs_decoder["decoder_attention_mask"] = build_mask(
decoder_input, self.tokenizer.pad_token_id
# Add padding tokens
decoder_input = torch.full(
(self.growing_beams.size(0), block_size),
self.pad_token_id,
dtype=torch.long,
device=self.growing_beams.device,
)
decoder_input[:, : self.growing_beams.size(1)] = self.growing_beams
outputs = self.model.decoder(decoder_input, kwargs_decoder)
log_probabilities = torch.nn.functional.log_softmax(outputs[1])
# compute decoder_attention_mask
decoder_mask = torch.ones_like(decoder_input)
idx_pad_tokens = decoder_input == self.pad_token_id
decoder_mask[idx_pad_tokens] = 0
kwargs_decoder["attention_mask"] = decoder_mask
outputs = self.model.decoder(decoder_input, **kwargs_decoder)
last_token_scores = outputs[0][:, -1, :].squeeze(1)
log_probabilities = torch.nn.functional.log_softmax(last_token_scores, dim=-1)
surviving_beams_rows = self.grow(log_probabilities)
if self.is_done:
break
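A standalone sketch of the padding-and-masking pattern used in the decoding loop, with random logits standing in for the decoder output so the snippet runs on its own; the log-softmax is taken over the vocabulary axis so each row is a proper next-token log-distribution:

    import torch

    pad_token_id, block_size, vocab_size = 0, 8, 50  # toy values
    growing_beams = torch.tensor([[101, 7, 12], [101, 7, 31]])  # (batch * beam, current_length)

    # right-pad the partial hypotheses up to the block size
    decoder_input = torch.full((growing_beams.size(0), block_size), pad_token_id, dtype=torch.long)
    decoder_input[:, : growing_beams.size(1)] = growing_beams

    # attention mask: 1 on real tokens, 0 on padding
    decoder_mask = torch.ones_like(decoder_input)
    decoder_mask[decoder_input == pad_token_id] = 0

    # stand-in for the decoder logits, shape (batch * beam, seq_len, vocab)
    logits = torch.randn(decoder_input.size(0), block_size, vocab_size)
    last_token_scores = logits[:, -1, :]
    log_probabilities = torch.nn.functional.log_softmax(last_token_scores, dim=-1)
    print(log_probabilities.exp().sum(dim=-1))  # ~1.0 per hypothesis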
@@ -189,13 +211,13 @@ class BeamSearch(nn.Module):
# Find the `beam_size` (previous_beam + token) combinations with
# the highest score
topk_log_probabilities, topk_ids = torch.topk(
self.topk_log_probabilities, topk_ids = torch.topk(
log_probabilities.view(_B, self.beam_size * vocab_size), self.beam_size, dim=1
)
# Apply the length penalty. The +1 accounts for the [EOS] token
# that will be added if the beam ends.
topk_scores = topk_log_probabilities
topk_scores = self.topk_log_probabilities
if self.apply_length_penalty:
topk_scores /= self._length_penalty()
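The topk over a `(batch, beam_size * vocab_size)` view scores every (parent hypothesis, next token) pair jointly; integer division and modulo then recover the two indices. The sketch below also shows a GNMT-style length penalty as an illustration of what `alpha` typically controls; the exact formula in `_length_penalty` is not shown in this diff, so treat that part as an assumption:

    import torch

    batch_size, beam_size, vocab_size, alpha = 2, 3, 5, 0.95  # toy values
    # one score per (hypothesis, candidate token), flattened over hypotheses
    log_probabilities = torch.randn(batch_size * beam_size, vocab_size)

    topk_log_probabilities, topk_ids = torch.topk(
        log_probabilities.view(batch_size, beam_size * vocab_size), beam_size, dim=1
    )
    surviving_beams = topk_ids // vocab_size  # index of the parent hypothesis
    next_tokens = topk_ids % vocab_size       # index of the appended token

    def length_penalty(step, alpha):
        # GNMT-style penalty (assumption); the +1 mirrors the [EOS] comment above
        return ((5.0 + step + 1) / 6.0) ** alpha

    topk_scores = topk_log_probabilities / length_penalty(step=4, alpha=alpha)
    print(surviving_beams.shape, next_tokens.shape, topk_scores.shape)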
@@ -337,8 +359,9 @@ def fit_to_block_size(sequence, block_size, pad_token_id):
if len(sequence) > block_size:
return sequence[:block_size]
else:
sequence.extend([pad_token_id] * (block_size - len(sequence)))
return sequence
return torch.cat(
(sequence, torch.tensor([pad_token_id] * (block_size - len(sequence)))), dim=0
)
def build_lm_labels(sequence, pad_token_id):
......
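The updated `fit_to_block_size` stays in tensor form (torch.cat instead of list.extend). A quick self-contained sketch of the truncate-or-right-pad behaviour on toy tensors; the function body is repeated here only so the snippet runs on its own:

    import torch

    def fit_to_block_size(sequence, block_size, pad_token_id):
        # truncate to `block_size`, or right-pad with `pad_token_id`
        if len(sequence) > block_size:
            return sequence[:block_size]
        return torch.cat(
            (sequence, torch.tensor([pad_token_id] * (block_size - len(sequence)))), dim=0
        )

    print(fit_to_block_size(torch.tensor([5, 6, 7]), block_size=5, pad_token_id=0))
    # tensor([5, 6, 7, 0, 0])
    print(fit_to_block_size(torch.tensor([5, 6, 7, 8, 9, 10]), block_size=5, pad_token_id=0))
    # tensor([5, 6, 7, 8, 9])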