Commit ca2047bc authored by Patrick von Platen

refactor variable naming and improve tf generate in line with torch generate

parent 41b437ea
@@ -459,6 +459,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
self,
input_ids=None,
max_length=None,
min_length=None,
do_sample=True,
early_stopping=False,
num_beams=None,
@@ -470,7 +471,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
pad_token_id=None,
eos_token_ids=None,
length_penalty=None,
no_repeat_ngram_size=None,
num_return_sequences=None,
attention_mask=None,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
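For reference, a minimal usage sketch of the refactored signature; the GPT-2 checkpoint and tokenizer names are illustrative assumptions, not part of this diff:

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("The quick brown fox", return_tensors="tf")
output = model.generate(
    input_ids,
    max_length=40,
    min_length=10,  # new argument in this diff
    no_repeat_ngram_size=2,  # new argument in this diff
    do_sample=False,
    num_beams=3,
)
print(tokenizer.decode(output[0].numpy()))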
@@ -564,6 +567,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
)
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
num_beams = num_beams if num_beams is not None else self.config.num_beams
@@ -575,6 +579,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
@@ -587,6 +594,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
eos_token_ids = [eos_token_ids]
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a non-negative integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
@@ -631,6 +639,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
num_beams >= num_return_sequences
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
# create attention mask if necessary
# TODO (PVP): this should later be handled by the forward fn() in each model; see PR 3140
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
elif attention_mask is None:
attention_mask = tf.ones_like(input_ids)
if pad_token_id is None and eos_token_ids is not None:
logger.warning(
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
@@ -655,42 +670,55 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
input_ids = tf.broadcast_to(
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
)
attention_mask = tf.broadcast_to(
tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
)
input_ids = tf.reshape(
input_ids, (effective_batch_size * num_beams, input_ids_len)
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
attention_mask = tf.reshape(
attention_mask, (effective_batch_size * num_beams, input_ids_len)
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if num_beams > 1:
output = self._generate_beam_search(
input_ids,
cur_len,
max_length,
do_sample,
early_stopping,
temperature,
top_k,
top_p,
repetition_penalty,
pad_token_id,
eos_token_ids,
effective_batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
attention_mask=attention_mask,
)
else:
output = self._generate_no_beam_search(
input_ids,
cur_len,
max_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
pad_token_id,
eos_token_ids,
effective_batch_size,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
batch_size=effective_batch_size,
vocab_size=vocab_size,
attention_mask=attention_mask,
)
return output
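The broadcast/reshape pair above turns each input row into a contiguous block of identical rows, one per returned sequence and beam. A standalone sketch with toy shapes:

import tensorflow as tf

batch_size, input_ids_len, num_beams = 2, 3, 2
effective_batch_mult = 2  # e.g. num_return_sequences when sampling
effective_batch_size = batch_size * effective_batch_mult

input_ids = tf.constant([[1, 2, 3], [4, 5, 6]])
expanded = tf.broadcast_to(
    tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
)
flat = tf.reshape(expanded, (effective_batch_size * num_beams, input_ids_len))
# flat.shape == (8, 3); rows 0-3 all equal [1, 2, 3], rows 4-7 all equal [4, 5, 6]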
@@ -700,14 +728,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
input_ids,
cur_len,
max_length,
min_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
pad_token_id,
eos_token_ids,
batch_size,
vocab_size,
attention_mask,
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
@@ -720,7 +752,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
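outputs[0] holds logits of shape (batch_size * num_beams, cur_len, vocab_size); the [:, -1, :] slice keeps only the last position, i.e. the next-token distribution for every sequence in the batch. A toy shape check:

import tensorflow as tf

logits = tf.random.normal((4, 7, 50257))  # (batch, cur_len, vocab_size), toy values
next_token_logits = logits[:, -1, :]
assert next_token_logits.shape == (4, 50257)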
@@ -735,6 +767,33 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
)
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
# create banned_tokens boolean mask
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
banned_tokens_indices_mask.append(
[token in banned_tokens_slice for token in range(vocab_size)]
)
next_token_logits = set_tensor_by_indices_to_value(
next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
# create eos_token_ids boolean mask
is_token_logit_eos_token = tf.convert_to_tensor(
[token in eos_token_ids for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
next_token_logits = set_tensor_by_indices_to_value(
next_token_logits, eos_token_indices_mask, -float("inf")
)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
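set_tensor_by_indices_to_value is called above but defined outside this hunk; a plausible one-liner with the behavior the call sites rely on (a sketch, not necessarily the committed implementation):

import tensorflow as tf

def set_tensor_by_indices_to_value(tensor, indices_mask, value):
    # where the boolean mask is True, overwrite with `value`; elsewhere keep the tensor
    return tf.where(indices_mask, value * tf.ones_like(tensor), tensor)

logits = tf.constant([[0.5, 1.0, -0.3]])
mask = tf.constant([[False, True, False]])
print(set_tensor_by_indices_to_value(logits, mask, -float("inf")))
# -> [[0.5, -inf, -0.3]]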
@@ -806,12 +865,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
input_ids,
cur_len,
max_length,
min_length,
do_sample,
early_stopping,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
pad_token_id,
eos_token_ids,
batch_size,
@@ -819,6 +880,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
length_penalty,
num_beams,
vocab_size,
attention_mask,
):
""" Generate sequences for each example with beam search.
"""
@@ -829,7 +891,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
for _ in range(batch_size)
]
# scores for each sentence in the beam
# for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens num_beams times
if do_sample is False:
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
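The concat/reshape that follows is cut off by the hunk; the intent is that for greedy beam search only the first beam of each batch entry starts competitive, since all num_beams rows are identical copies at step one. A reconstruction sketch:

import tensorflow as tf

batch_size, num_beams = 2, 3
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
beam_scores = tf.reshape(
    tf.concat([beam_scores_begin, beam_scores_end], -1), (batch_size * num_beams,)
)
# -> [0, -1e9, -1e9, 0, -1e9, -1e9]: beams 1..num_beams-1 cannot win the first top-k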
@@ -845,7 +907,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
@@ -860,12 +922,42 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
)
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
# create banned_tokens boolean mask
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
banned_tokens_indices_mask.append(
[token in banned_tokens_slice for token in range(vocab_size)]
)
next_token_logits = set_tensor_by_indices_to_value(
next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
# create eos_token_ids boolean mask
is_token_logit_eos_token = tf.convert_to_tensor(
[token in eos_token_ids for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
next_token_logits = set_tensor_by_indices_to_value(
next_token_logits, eos_token_indices_mask, -float("inf")
)
# calculate log softmax score
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
assert shape_list(scores) == [batch_size * num_beams, vocab_size]
if do_sample:
_scores = scores + tf.broadcast_to(
beam_scores[:, None], (batch_size * num_beams, vocab_size)
) # (batch_size * num_beams, vocab_size)
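The sampling branch continues past this hunk; conceptually, the per-beam scores are flattened so each batch row can draw candidates across all of its beams at once. An illustrative sketch (the committed code may sample without replacement rather than with tf.random.categorical):

import tensorflow as tf

batch_size, num_beams, vocab_size = 2, 3, 5
_scores = tf.random.normal((batch_size * num_beams, vocab_size))
_scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
next_tokens = tf.random.categorical(_scores, num_samples=2 * num_beams)
# next_tokens: (batch_size, 2 * num_beams), flattened candidate ids in [0, num_beams * vocab_size)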
@@ -888,9 +980,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2)
next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2)
else:
# do greedy beam search
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
assert shape_list(scores) == [batch_size * num_beams, vocab_size]
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
next_scores = scores + tf.broadcast_to(
beam_scores[:, None], (batch_size * num_beams, vocab_size)
@@ -912,10 +1001,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# for each sentence
for batch_idx in range(batch_size):
# if we are done with this sentence
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
tf.reduce_max(next_scores[batch_idx]).numpy()
)
if done[batch_idx]:
assert (
len(generated_hyps[batch_idx]) >= num_beams
@@ -930,29 +1015,46 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
next_sent_beam = []
# next tokens for this sentence
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx])
):
# get beam and token IDs
beam_id = idx // vocab_size
token_id = idx % vocab_size
beam_id = beam_token_id // vocab_size
token_id = beam_token_id % vocab_size
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence or last iteration
if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
generated_hyps[batch_idx].add(tf.identity(input_ids[effective_beam_id]), score.numpy())
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
if is_beam_token_worse_than_top_num_beams:
continue
generated_hyps[batch_idx].add(
tf.identity(input_ids[effective_beam_id]), beam_token_score.numpy()
)
else:
# add next predicted token if it is not eos_token
next_sent_beam.append((score, token_id, effective_beam_id))
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
# the beam for next step is full
if len(next_sent_beam) == num_beams:
break
# if we are done with this sentence
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
tf.reduce_max(next_scores[batch_idx]).numpy()
)
# update next beam content
assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
# stop when we are done with each sentence
if all(done):
break
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
@@ -967,10 +1069,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if past:
past = self._reorder_cache(past, beam_idx)
# stop when we are done with each sentence
if all(done):
break
# update current length
cur_len = cur_len + 1
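_reorder_cache is referenced above but not part of this diff; in the TF models of this era the cached states carry the batch * beam dimension on axis 1, so reordering plausibly reduces to a gather (a sketch under that assumption):

import tensorflow as tf

def _reorder_cache_sketch(past, beam_idx):
    # keep, for every layer, only the cached states of the beams that survived this step
    return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)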
@@ -1072,6 +1170,29 @@ def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
# Copied from fairseq for no_repeat_ngram in beam_search
if cur_len + 1 < no_repeat_ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
return [[] for _ in range(num_hypos)]
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].numpy().tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
# Before decoding the next token, prevent decoding of ngrams that have already appeared
start_idx = cur_len + 1 - no_repeat_ngram_size
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
return generated_ngrams[hypo_idx].get(ngram_idx, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
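A worked toy example of the helper above: with no_repeat_ngram_size=2 and the already generated sequence [1, 2, 3, 2], the seen bigrams are (1,2), (2,3), (3,2); the last token is 2, so emitting 3 would repeat the bigram (2,3):

import tensorflow as tf

prev_input_ids = tf.constant([[1, 2, 3, 2]])
print(calc_banned_tokens(prev_input_ids, num_hypos=1, no_repeat_ngram_size=2, cur_len=4))
# -> [[3]]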
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
......
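A usage sketch of the filter, following the docstring semantics above; with top_k=2 only the two highest logits survive and everything else is pushed to the filter value before sampling:

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.5, -1.0]])
filtered = tf_top_k_top_p_filtering(logits, top_k=2)
# filtered -> [[2.0, 1.0, -inf, -inf]]; -inf logits get zero probability
next_token = tf.random.categorical(filtered, num_samples=1)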
@@ -722,6 +722,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
@@ -852,6 +853,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
device=next(self.parameters()).device,
)
cur_len = 1
# put model in generation mode if it has one
if hasattr(self.model, "generation_mode"):
self.model.decoder.generation_mode = True
else:
encoder_inputs = None
@@ -860,44 +864,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if num_beams > 1:
output = self._generate_beam_search(
input_ids,
cur_len,
max_length,
min_length,
do_sample,
early_stopping,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
effective_batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
encoder_inputs,
attention_mask,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_inputs=encoder_inputs,
attention_mask=attention_mask,
)
else:
output = self._generate_no_beam_search(
input_ids,
cur_len,
max_length,
min_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
pad_token_id,
eos_token_ids,
effective_batch_size,
encoder_inputs,
attention_mask,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
batch_size=effective_batch_size,
encoder_inputs=encoder_inputs,
attention_mask=attention_mask,
)
return output
@@ -1157,24 +1161,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_sent_beam = []
# next tokens for this sentence
for i, (idx, score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx])
):
# get beam and word IDs
beam_id = idx // vocab_size
token_id = idx % vocab_size
beam_id = beam_token_id // vocab_size
token_id = beam_token_id % vocab_size
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence
if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
# when passed to num_beams hypotheses, continue
if i >= num_beams:
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
if is_beam_token_worse_than_top_num_beams:
continue
generated_hyps[batch_idx].add(
input_ids[effective_beam_id].clone(), score.item(),
input_ids[effective_beam_id].clone(), beam_token_score.item(),
)
else:
# add next predicted word if it is not eos_token
next_sent_beam.append((score, token_id, effective_beam_id))
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
# the beam for next step is full
if len(next_sent_beam) == num_beams:
......
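Both the TF and torch loops above decompose the flattened candidate id the same way; a worked example of the arithmetic (values illustrative):

vocab_size = 50257
beam_token_id = 100519
beam_id = beam_token_id // vocab_size   # 2: the candidate came from beam 2
token_id = beam_token_id % vocab_size   # 5: it proposes vocabulary token 5
assert (beam_id, token_id) == (2, 5)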