Unverified Commit 006097f8 authored by Patrick von Platen, committed by GitHub

rename variables named 'word' to 'token' in generate fn (#3119)

* fix conflicts

* fixed naming bug

* make style
parent 71c87119
@@ -242,7 +242,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
- # Copy word embeddings from the previous weights
+ # Copy token embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
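For context: this hunk sits in the embedding-resizing helper, where a freshly initialized embedding matrix receives the rows of the old one. A minimal standalone sketch of that copy, with toy sizes (not part of the commit):

import torch
import torch.nn as nn

old_num_tokens, new_num_tokens, dim = 10, 13, 4  # toy sizes
old_embeddings = nn.Embedding(old_num_tokens, dim)
new_embeddings = nn.Embedding(new_num_tokens, dim)  # freshly initialized rows

# copy the rows present in both matrices; any extra rows keep their fresh init
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]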
@@ -558,7 +558,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
model.__class__.__name__, "\n\t".join(error_msgs)
)
)
- model.tie_weights() # make sure word embedding weights are still tied if needed
+ model.tie_weights() # make sure token embedding weights are still tied if needed
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
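For context: tie_weights re-links the input token embedding and the output projection after loading, so both keep sharing one weight matrix. A rough sketch of what tying means in plain PyTorch (hypothetical module names, not the library code):

import torch.nn as nn

vocab_size, hidden = 100, 32
embedding = nn.Embedding(vocab_size, hidden)
lm_head = nn.Linear(hidden, vocab_size, bias=False)

# tie: the LM head reuses the embedding matrix, so resizing or reloading one
# of them automatically affects the other
lm_head.weight = embedding.weight
assert lm_head.weight.data_ptr() == embedding.weight.data_ptr()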
@@ -843,8 +843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
"""
- # current position / max lengths / length of generated sentences / unfinished sentences
+ # length of generated sentences / unfinished sentences
unfinished_sents = input_ids.new(batch_size).fill_(1)
sent_lengths = input_ids.new(batch_size).fill_(max_length)
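For context: in the no-beam-search branch, unfinished_sents is a 0/1 mask over the batch and sent_lengths records where each sequence stopped. A toy sketch of how such a mask is typically updated once an EOS token appears (assumed shapes and values, not the commit's code):

import torch

batch_size, max_length, eos_token_id = 3, 10, 2
unfinished_sents = torch.ones(batch_size, dtype=torch.long)            # 1 = still generating
sent_lengths = torch.full((batch_size,), max_length, dtype=torch.long)

cur_len = 5
next_tokens = torch.tensor([2, 7, 2])  # pretend the model picked these

eos_in_sents = next_tokens == eos_token_id
# record the length the first time a sentence hits EOS, then mark it finished
newly_finished = unfinished_sents.mul(eos_in_sents.long()).bool()
sent_lengths.masked_fill_(newly_finished, cur_len + 1)
unfinished_sents.mul_((~eos_in_sents).long())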
@@ -934,7 +933,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
"""
# Expand input to num beams
- # assert input_ids.shape == (batch_size * num_beams, cur_len)
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
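For context: the prompt is expanded so that every beam starts from the same prefix; each row of input_ids is repeated num_beams times and flattened back into a 2-D batch. A small sketch with toy values (not part of the diff):

import torch

batch_size, num_beams, cur_len = 2, 3, 4
input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)

# (batch_size, cur_len) -> (batch_size, num_beams, cur_len) -> (batch_size * num_beams, cur_len)
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)
print(input_ids.shape)  # torch.Size([6, 4])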
@@ -946,7 +944,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
- # Greedy decoding it is made sure that only words of the first beam are considered to avoid sampling the exact same words three times
+ # Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
if do_sample is False:
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
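For context: with greedy beam search all beams start from an identical prefix, so without the -1e9 offset every beam would propose the same top tokens; forcing beams 1..num_beams-1 to a very low score makes the first step draw only from beam 0. A toy illustration of the initialization (assumed sizes):

import torch

batch_size, num_beams, do_sample = 2, 3, False

beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float)
if do_sample is False:
    # only beam 0 can contribute candidates at the first step
    beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)
print(beam_scores)  # beam 0 of each batch element is 0, the rest are -1e9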
@@ -960,7 +958,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
- scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
+ next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
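For context: the model returns logits for every position, but only the last position matters for choosing the next token, which is what the [:, -1, :] slice (now named next_token_logits) selects. A toy sketch with a random tensor standing in for outputs[0]:

import torch

batch_times_beams, cur_len, vocab_size = 6, 4, 11
logits = torch.randn(batch_times_beams, cur_len, vocab_size)  # stand-in for outputs[0]

next_token_logits = logits[:, -1, :]   # (batch_size * num_beams, vocab_size)
print(next_token_logits.shape)         # torch.Size([6, 11])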
@@ -968,14 +966,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
- self.enforce_repetition_penalty_(scores, batch_size, num_beams, input_ids, repetition_penalty)
+ self.enforce_repetition_penalty_(
+ next_token_logits, batch_size, num_beams, input_ids, repetition_penalty
+ )
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
- scores = scores / temperature
+ next_token_logits = next_token_logits / temperature
- scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
+ scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# Top-p/top-k filtering
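For context: before candidates are picked, next_token_logits goes through a CTRL-style repetition penalty on already-generated tokens, temperature scaling (sampling path only), and a log-softmax whose output is added to the running beam_scores; top-k/top-p filtering follows. A compact sketch of those steps, where the penalty loop is a simplified stand-in for enforce_repetition_penalty_ rather than the library implementation:

import torch
import torch.nn.functional as F

num_rows, vocab_size = 6, 11                  # num_rows = batch_size * num_beams
temperature, repetition_penalty = 0.7, 1.2
next_token_logits = torch.randn(num_rows, vocab_size)
input_ids = torch.randint(vocab_size, (num_rows, 4))   # tokens generated so far
beam_scores = torch.zeros(num_rows)

# CTRL-style penalty: make tokens that already occurred less likely
for i in range(num_rows):
    for prev_token in set(input_ids[i].tolist()):
        if next_token_logits[i, prev_token] < 0:
            next_token_logits[i, prev_token] *= repetition_penalty
        else:
            next_token_logits[i, prev_token] /= repetition_penalty

next_token_logits = next_token_logits / temperature        # sampling path only
scores = F.log_softmax(next_token_logits, dim=-1)          # (num_rows, vocab_size)
_scores = scores + beam_scores[:, None].expand_as(scores)  # add running beam log-probs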
@@ -988,25 +988,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
batch_size, num_beams * vocab_size
) # (batch_size, num_beams * vocab_size)
- # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
- next_words = torch.multinomial(
+ # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
+ next_tokens = torch.multinomial(
F.softmax(_scores, dim=-1), num_samples=2 * num_beams
) # (batch_size, num_beams * 2)
# Compute next scores
- next_scores = torch.gather(_scores, -1, next_words) # (batch_size, num_beams * 2)
+ next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
else:
# do greedy beam search
- scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
+ scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
assert scores.size() == (batch_size * num_beams, vocab_size)
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
- _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
+ next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis across beams)
- _scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
- next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+ next_scores = next_scores.view(
+ batch_size, num_beams * vocab_size
+ ) # (batch_size, num_beams * vocab_size)
+ next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
- assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
+ assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
# next batch beam content
# list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
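For context: candidate selection flattens the per-beam scores into one row of length num_beams * vocab_size per batch element and keeps 2 * num_beams candidates, either by multinomial sampling plus a gather of their scores or, in the greedy branch, by a plain top-k; the spare candidates cover the case where some of them turn out to be EOS. A toy sketch of the greedy branch (assumed sizes, not the commit's code):

import torch
import torch.nn.functional as F

batch_size, num_beams, vocab_size = 2, 3, 11
scores = F.log_softmax(torch.randn(batch_size * num_beams, vocab_size), dim=-1)
beam_scores = torch.zeros(batch_size * num_beams)

next_scores = scores + beam_scores[:, None].expand_as(scores)
next_scores = next_scores.view(batch_size, num_beams * vocab_size)

# keep 2 * num_beams candidates per batch element so EOS hits do not starve the beam
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)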
@@ -1032,21 +1034,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# next sentence beam content
next_sent_beam = []
- # next words for this sentence
- for idx, score in zip(next_words[batch_idx], next_scores[batch_idx]):
+ # next tokens for this sentence
+ for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
# get beam and word IDs
beam_id = idx // vocab_size
- word_id = idx % vocab_size
+ token_id = idx % vocab_size
# add to generated hypotheses if end of sentence or last iteration
- if eos_token_ids is not None and word_id.item() in eos_token_ids:
+ if eos_token_ids is not None and token_id.item() in eos_token_ids:
generated_hyps[batch_idx].add(
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(),
)
else:
# add next predicted word if it is not eos_token
- next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))
+ next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))
# the beam for next step is full
if len(next_sent_beam) == num_beams:
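For context: the top-k above ran over a flattened num_beams * vocab_size axis, so each candidate index has to be split back into the beam it extends and the vocabulary id; integer division recovers the beam, the remainder recovers the token. A small sketch of that decoding with toy numbers:

import torch

num_beams, vocab_size = 3, 11
next_tokens = torch.tensor([5, 14, 27])   # flat indices in [0, num_beams * vocab_size)

for idx in next_tokens:
    beam_id = idx // vocab_size           # which beam the candidate extends
    token_id = idx % vocab_size           # which vocabulary token it appends
    print(int(beam_id), int(token_id))    # prints 0 5, then 1 3, then 2 5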
@@ -1060,12 +1062,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
- beam_words = input_ids.new([x[1] for x in next_batch_beam])
+ beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order batch
input_ids = input_ids[beam_idx, :]
- input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
+ input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
# re-order internal states
if past:
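For context: after the candidates are chosen, each surviving hypothesis may continue a different parent beam, so input_ids is re-indexed with beam_idx before the chosen beam_tokens are appended as a new column (and any cached past states are re-ordered the same way). A toy illustration with made-up values, independent of the diff:

import torch

# 2 batch elements x 2 beams, sequences of length 3
input_ids = torch.tensor([[1, 2, 3],
                          [4, 5, 6],
                          [7, 8, 9],
                          [10, 11, 12]])
beam_tokens = torch.tensor([13, 14, 15, 16])  # next token for each surviving hypothesis
beam_idx = torch.tensor([1, 1, 2, 3])         # parent row each hypothesis continues

input_ids = input_ids[beam_idx, :]            # re-order rows to the chosen parents
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
print(input_ids)
# tensor([[ 4,  5,  6, 13],
#         [ 4,  5,  6, 14],
#         [ 7,  8,  9, 15],
#         [10, 11, 12, 16]])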
@@ -1081,11 +1083,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
for batch_idx in range(batch_size):
# Add all open beam hypothesis to generated_hyps
if not done[batch_idx]:
- for idx, score in zip(next_words[batch_idx], next_scores[batch_idx]):
+ for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
# get beam and word IDs
beam_id = idx // vocab_size
- word_id = idx % vocab_size
+ token_id = idx % vocab_size
generated_hyps[batch_idx].add(
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
)