"...git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "d9ae8200be9f14669828c8a94af2a2cef7d1e36e"
Commit ca2047bc authored by Patrick von Platen

refactor variable naming and improve tf generate in line with torch generate

parent 41b437ea
@@ -459,6 +459,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        self,
        input_ids=None,
        max_length=None,
+       min_length=None,
        do_sample=True,
        early_stopping=False,
        num_beams=None,
@@ -470,7 +471,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        pad_token_id=None,
        eos_token_ids=None,
        length_penalty=None,
+       no_repeat_ngram_size=None,
        num_return_sequences=None,
+       attention_mask=None,
    ):
        r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
        and beam-search.
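As a reading aid (not part of the commit): a minimal usage sketch of the extended signature, exercising the new min_length and no_repeat_ngram_size arguments. The checkpoint name and prompt are illustrative assumptions.

    from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = TFGPT2LMHeadModel.from_pretrained("gpt2")

    input_ids = tokenizer.encode("The weather today is", return_tensors="tf")
    output_ids = model.generate(
        input_ids,
        max_length=40,
        min_length=10,             # new argument added in this commit
        no_repeat_ngram_size=2,    # new argument added in this commit
        do_sample=True,
        top_k=50,
    )
    print(tokenizer.decode(output_ids[0].numpy().tolist(), skip_special_tokens=True))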
@@ -564,6 +567,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        )
        max_length = max_length if max_length is not None else self.config.max_length
+       min_length = min_length if min_length is not None else self.config.min_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beams = num_beams if num_beams is not None else self.config.num_beams
@@ -575,6 +579,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+       no_repeat_ngram_size = (
+           no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
+       )
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )
@@ -587,6 +594,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
            eos_token_ids = [eos_token_ids]
        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
+       assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
@@ -631,6 +639,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                num_beams >= num_return_sequences
            ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

+       # create attention mask if necessary
+       # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
+       if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
+           attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
+       elif attention_mask is None:
+           attention_mask = tf.ones_like(input_ids)

        if pad_token_id is None and eos_token_ids is not None:
            logger.warning(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
@@ -655,42 +670,55 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
            input_ids = tf.broadcast_to(
                tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
            )
+           attention_mask = tf.broadcast_to(
+               tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
+           )
            input_ids = tf.reshape(
                input_ids, (effective_batch_size * num_beams, input_ids_len)
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
+           attention_mask = tf.reshape(
+               attention_mask, (effective_batch_size * num_beams, input_ids_len)
+           )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

        if num_beams > 1:
            output = self._generate_beam_search(
                input_ids,
-               cur_len,
-               max_length,
-               do_sample,
-               early_stopping,
-               temperature,
-               top_k,
-               top_p,
-               repetition_penalty,
-               pad_token_id,
-               eos_token_ids,
-               effective_batch_size,
-               num_return_sequences,
-               length_penalty,
-               num_beams,
-               vocab_size,
+               cur_len=cur_len,
+               max_length=max_length,
+               min_length=min_length,
+               do_sample=do_sample,
+               early_stopping=early_stopping,
+               temperature=temperature,
+               top_k=top_k,
+               top_p=top_p,
+               repetition_penalty=repetition_penalty,
+               no_repeat_ngram_size=no_repeat_ngram_size,
+               pad_token_id=pad_token_id,
+               eos_token_ids=eos_token_ids,
+               batch_size=effective_batch_size,
+               num_return_sequences=num_return_sequences,
+               length_penalty=length_penalty,
+               num_beams=num_beams,
+               vocab_size=vocab_size,
+               attention_mask=attention_mask,
            )
        else:
            output = self._generate_no_beam_search(
                input_ids,
-               cur_len,
-               max_length,
-               do_sample,
-               temperature,
-               top_k,
-               top_p,
-               repetition_penalty,
-               pad_token_id,
-               eos_token_ids,
-               effective_batch_size,
+               cur_len=cur_len,
+               max_length=max_length,
+               min_length=min_length,
+               do_sample=do_sample,
+               temperature=temperature,
+               top_k=top_k,
+               top_p=top_p,
+               repetition_penalty=repetition_penalty,
+               no_repeat_ngram_size=no_repeat_ngram_size,
+               pad_token_id=pad_token_id,
+               eos_token_ids=eos_token_ids,
+               batch_size=effective_batch_size,
+               vocab_size=vocab_size,
+               attention_mask=attention_mask,
            )

        return output
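The broadcast/reshape pair above tiles every input row once per returned sequence and per beam; a small shape check under assumed toy sizes (not part of the commit):

    import tensorflow as tf

    batch_size, input_ids_len = 2, 4
    effective_batch_mult, num_beams = 3, 2            # e.g. 3 return sequences with sampling, 2 beams
    effective_batch_size = batch_size * effective_batch_mult

    input_ids = tf.zeros((batch_size, input_ids_len), dtype=tf.int32)
    input_ids = tf.broadcast_to(
        tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
    )
    input_ids = tf.reshape(input_ids, (effective_batch_size * num_beams, input_ids_len))
    print(input_ids.shape)   # (12, 4) == (batch_size * num_return_sequences * num_beams, cur_len)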
@@ -700,14 +728,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        input_ids,
        cur_len,
        max_length,
+       min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
+       no_repeat_ngram_size,
        pad_token_id,
        eos_token_ids,
        batch_size,
+       vocab_size,
+       attention_mask,
    ):
        """ Generate sequences for each example without beam search (num_beams == 1).
        All returned sequence are generated independantly.
@@ -720,7 +752,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        past = None
        while cur_len < max_length:
-           model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
+           model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
            outputs = self(**model_inputs)
            next_token_logits = outputs[0][:, -1, :]
@@ -735,6 +767,33 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                )
                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)

+           if no_repeat_ngram_size > 0:
+               # calculate a list of banned tokens to prevent repetitively generating the same ngrams
+               # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
+               banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
+               # create banned_tokens boolean mask
+               banned_tokens_indices_mask = []
+               for banned_tokens_slice in banned_tokens:
+                   banned_tokens_indices_mask.append(
+                       [True if token in banned_tokens_slice else False for token in range(vocab_size)]
+                   )
+               next_token_logits = set_tensor_by_indices_to_value(
+                   next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
+               )

+           # set eos token prob to zero if min_length is not reached
+           if eos_token_ids is not None and cur_len < min_length:
+               # create eos_token_ids boolean mask
+               is_token_logit_eos_token = tf.convert_to_tensor(
+                   [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
+               )
+               eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
+               next_token_logits = set_tensor_by_indices_to_value(
+                   next_token_logits, eos_token_indices_mask, -float("inf")
+               )

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
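Both masking blocks above funnel into the same operation: flagged vocabulary positions are forced to -inf so they can never be chosen. A toy equivalent using tf.where as a stand-in for the set_tensor_by_indices_to_value helper, whose body is not part of this diff:

    import tensorflow as tf

    next_token_logits = tf.constant([[0.2, 1.5, -0.3, 0.9]])
    banned_mask = tf.constant([[False, False, True, False]])   # pretend token 2 is banned (EOS or repeated ngram)
    masked = tf.where(banned_mask, -float("inf") * tf.ones_like(next_token_logits), next_token_logits)
    print(masked.numpy())   # [[ 0.2  1.5 -inf  0.9]]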
@@ -806,12 +865,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        input_ids,
        cur_len,
        max_length,
+       min_length,
        do_sample,
        early_stopping,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
+       no_repeat_ngram_size,
        pad_token_id,
        eos_token_ids,
        batch_size,
@@ -819,6 +880,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        length_penalty,
        num_beams,
        vocab_size,
+       attention_mask,
    ):
        """ Generate sequences for each example with beam search.
        """
@@ -829,7 +891,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
            for _ in range(batch_size)
        ]

-       # scores for each sentence in the beam
+       # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
        if do_sample is False:
            beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
            beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
@@ -845,7 +907,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        done = [False for _ in range(batch_size)]
        while cur_len < max_length:
-           model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
+           model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
@@ -860,12 +922,42 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                )
                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)

-           if do_sample:
-               # Temperature (higher temperature => more likely to sample low probability tokens)
-               if temperature != 1.0:
-                   next_token_logits = next_token_logits / temperature
-               scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+           # Temperature (higher temperature => more likely to sample low probability tokens)
+           if temperature != 1.0:
+               next_token_logits = next_token_logits / temperature

+           if no_repeat_ngram_size > 0:
+               # calculate a list of banned tokens to prevent repetitively generating the same ngrams
+               # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
+               banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
+               # create banned_tokens boolean mask
+               banned_tokens_indices_mask = []
+               for banned_tokens_slice in banned_tokens:
+                   banned_tokens_indices_mask.append(
+                       [True if token in banned_tokens_slice else False for token in range(vocab_size)]
+                   )
+               next_token_logits = set_tensor_by_indices_to_value(
+                   next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
+               )

+           # set eos token prob to zero if min_length is not reached
+           if eos_token_ids is not None and cur_len < min_length:
+               # create eos_token_ids boolean mask
+               is_token_logit_eos_token = tf.convert_to_tensor(
+                   [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
+               )
+               eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
+               next_token_logits = set_tensor_by_indices_to_value(
+                   next_token_logits, eos_token_indices_mask, -float("inf")
+               )

+           # calculate log softmax score
+           scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+           assert shape_list(scores) == [batch_size * num_beams, vocab_size]

+           if do_sample:
                _scores = scores + tf.broadcast_to(
                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
                )  # (batch_size * num_beams, vocab_size)
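A tiny numeric illustration of the score accumulation shown above: per-token log-probabilities are added to each beam's running log-probability before the top candidates are picked (numbers invented, 2 beams over a 3-token vocabulary):

    import tensorflow as tf

    next_token_logits = tf.constant([[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]])   # (num_beams, vocab_size)
    beam_scores = tf.constant([-0.3, -1.2])                               # cumulative log-probs of the beams
    scores = tf.nn.log_softmax(next_token_logits, axis=-1)
    _scores = scores + tf.broadcast_to(beam_scores[:, None], (2, 3))
    print(_scores.numpy())   # each row is shifted by its beam's cumulative log-probability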
@@ -888,9 +980,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1)  # (batch_size, num_beams * 2)
                next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1)  # (batch_size, num_beams * 2)
            else:
-               # do greedy beam search
-               scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
-               assert shape_list(scores) == [batch_size * num_beams, vocab_size]
                # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
                next_scores = scores + tf.broadcast_to(
                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
@@ -912,10 +1001,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
            # for each sentence
            for batch_idx in range(batch_size):
-               # if we are done with this sentence
-               done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
-                   tf.reduce_max(next_scores[batch_idx]).numpy()
-               )
                if done[batch_idx]:
                    assert (
                        len(generated_hyps[batch_idx]) >= num_beams
@@ -930,29 +1015,46 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                next_sent_beam = []

                # next tokens for this sentence
-               for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
+               for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
+                   zip(next_tokens[batch_idx], next_scores[batch_idx])
+               ):
                    # get beam and token IDs
-                   beam_id = idx // vocab_size
-                   token_id = idx % vocab_size
+                   beam_id = beam_token_id // vocab_size
+                   token_id = beam_token_id % vocab_size
                    effective_beam_id = batch_idx * num_beams + beam_id
                    # add to generated hypotheses if end of sentence or last iteration
                    if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
-                       generated_hyps[batch_idx].add(tf.identity(input_ids[effective_beam_id]), score.numpy())
+                       # if beam_token does not belong to top num_beams tokens, it should not be added
+                       is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
+                       if is_beam_token_worse_than_top_num_beams:
+                           continue
+                       generated_hyps[batch_idx].add(
+                           tf.identity(input_ids[effective_beam_id]), beam_token_score.numpy()
+                       )
                    else:
                        # add next predicted token if it is not eos_token
-                       next_sent_beam.append((score, token_id, effective_beam_id))
+                       next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
                        break

+               # if we are done with this sentence
+               done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
+                   tf.reduce_max(next_scores[batch_idx]).numpy()
+               )

                # update next beam content
                assert len(next_sent_beam) == num_beams, "Beam should always be full"
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (batch_idx + 1)

+           # stop when we are done with each sentence
+           if all(done):
+               break

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
            beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
@@ -967,10 +1069,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
            if past:
                past = self._reorder_cache(past, beam_idx)

-           # stop when we are done with each sentence
-           if all(done):
-               break

            # update current length
            cur_len = cur_len + 1
@@ -1072,6 +1170,29 @@ def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
    return tf.convert_to_tensor(token_penalties, dtype=tf.float32)

+def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
+   # Copied from fairseq for no_repeat_ngram in beam_search"""
+   if cur_len + 1 < no_repeat_ngram_size:
+       # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+       return [[] for _ in range(num_hypos)]
+   generated_ngrams = [{} for _ in range(num_hypos)]
+   for idx in range(num_hypos):
+       gen_tokens = prev_input_ids[idx].numpy().tolist()
+       generated_ngram = generated_ngrams[idx]
+       for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
+           prev_ngram_tuple = tuple(ngram[:-1])
+           generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

+   def _get_generated_ngrams(hypo_idx):
+       # Before decoding the next token, prevent decoding of ngrams that have already appeared
+       start_idx = cur_len + 1 - no_repeat_ngram_size
+       ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
+       return generated_ngrams[hypo_idx].get(ngram_idx, [])

+   banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
+   return banned_tokens
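A quick check of the helper above on a made-up hypothesis (assumes eager execution so .numpy() works):

    import tensorflow as tf

    prev_input_ids = tf.constant([[3, 4, 5, 3, 4]])   # the bigrams (3, 4) and (4, 5) already occurred
    banned = calc_banned_tokens(prev_input_ids, num_hypos=1, no_repeat_ngram_size=2, cur_len=5)
    print(banned)   # [[5]] -> emitting 5 after the trailing 4 would repeat the bigram (4, 5)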
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
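A hedged sketch of calling the filtering function whose signature appears above (its body is not shown in this diff; with toy logits and top_k=2, everything outside the two highest-scoring tokens should come back as -inf):

    import tensorflow as tf

    logits = tf.constant([[1.0, 3.0, 0.5, 2.0]])
    filtered = tf_top_k_top_p_filtering(logits, top_k=2)
    print(filtered.numpy())   # tokens 1 and 3 keep their logits, the rest are filtered to -inf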
...
@@ -722,6 +722,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
+       do_sample = do_sample if do_sample is not None else self.config.do_sample
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        temperature = temperature if temperature is not None else self.config.temperature
@@ -852,7 +853,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                device=next(self.parameters()).device,
            )
            cur_len = 1
-           self.model.decoder.generation_mode = True
+
+           # put model in generation mode if it has one
+           if hasattr(self.model, "generation_mode"):
+               self.model.decoder.generation_mode = True
        else:
            encoder_inputs = None
            cur_len = input_ids.shape[-1]
@@ -860,44 +864,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        if num_beams > 1:
            output = self._generate_beam_search(
                input_ids,
-               cur_len,
-               max_length,
-               min_length,
-               do_sample,
-               early_stopping,
-               temperature,
-               top_k,
-               top_p,
-               repetition_penalty,
-               no_repeat_ngram_size,
-               bos_token_id,
-               pad_token_id,
-               eos_token_ids,
-               effective_batch_size,
-               num_return_sequences,
-               length_penalty,
-               num_beams,
-               vocab_size,
-               encoder_inputs,
-               attention_mask,
+               cur_len=cur_len,
+               max_length=max_length,
+               min_length=min_length,
+               do_sample=do_sample,
+               early_stopping=early_stopping,
+               temperature=temperature,
+               top_k=top_k,
+               top_p=top_p,
+               repetition_penalty=repetition_penalty,
+               no_repeat_ngram_size=no_repeat_ngram_size,
+               bos_token_id=bos_token_id,
+               pad_token_id=pad_token_id,
+               eos_token_ids=eos_token_ids,
+               batch_size=effective_batch_size,
+               num_return_sequences=num_return_sequences,
+               length_penalty=length_penalty,
+               num_beams=num_beams,
+               vocab_size=vocab_size,
+               encoder_inputs=encoder_inputs,
+               attention_mask=attention_mask,
            )
        else:
            output = self._generate_no_beam_search(
                input_ids,
-               cur_len,
-               max_length,
-               min_length,
-               do_sample,
-               temperature,
-               top_k,
-               top_p,
-               repetition_penalty,
-               no_repeat_ngram_size,
-               pad_token_id,
-               eos_token_ids,
-               effective_batch_size,
-               encoder_inputs,
-               attention_mask,
+               cur_len=cur_len,
+               max_length=max_length,
+               min_length=min_length,
+               do_sample=do_sample,
+               temperature=temperature,
+               top_k=top_k,
+               top_p=top_p,
+               repetition_penalty=repetition_penalty,
+               no_repeat_ngram_size=no_repeat_ngram_size,
+               pad_token_id=pad_token_id,
+               eos_token_ids=eos_token_ids,
+               batch_size=effective_batch_size,
+               encoder_inputs=encoder_inputs,
+               attention_mask=attention_mask,
            )

        return output
@@ -1157,24 +1161,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                next_sent_beam = []

                # next tokens for this sentence
-               for i, (idx, score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):
+               for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
+                   zip(next_tokens[batch_idx], next_scores[batch_idx])
+               ):
                    # get beam and word IDs
-                   beam_id = idx // vocab_size
-                   token_id = idx % vocab_size
+                   beam_id = beam_token_id // vocab_size
+                   token_id = beam_token_id % vocab_size
                    effective_beam_id = batch_idx * num_beams + beam_id
                    # add to generated hypotheses if end of sentence
                    if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
-                       # when passed to num_beams hypotheses, continue
-                       if i >= num_beams:
+                       # if beam_token does not belong to top num_beams tokens, it should not be added
+                       is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
+                       if is_beam_token_worse_than_top_num_beams:
                            continue
                        generated_hyps[batch_idx].add(
-                           input_ids[effective_beam_id].clone(), score.item(),
+                           input_ids[effective_beam_id].clone(), beam_token_score.item(),
                        )
                    else:
                        # add next predicted word if it is not eos_token
-                       next_sent_beam.append((score, token_id, effective_beam_id))
+                       next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
...
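For intuition on the beam_token_rank check introduced in both frameworks: 2 * num_beams candidates are scored per sentence, but a candidate ending in EOS is only kept as a finished hypothesis if it ranks inside the top num_beams. A toy walk-through, pure Python with invented values (not part of the commit):

    num_beams = 2
    # (rank, token_is_eos) for the 2 * num_beams candidates of one sentence, best first
    candidates = [(0, True), (1, False), (2, True), (3, False)]
    kept_hypotheses, next_sent_beam = [], []
    for beam_token_rank, is_eos in candidates:
        if is_eos:
            if beam_token_rank >= num_beams:   # worse than the top num_beams -> drop it
                continue
            kept_hypotheses.append(beam_token_rank)
        else:
            next_sent_beam.append(beam_token_rank)
        if len(next_sent_beam) == num_beams:
            break
    print(kept_hypotheses, next_sent_beam)   # [0] [1, 3]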