Unverified commit b38d552a authored by Patrick von Platen, committed by GitHub

[Generate] Add bad words list argument to the generate function (#3367)

* add bad words list

* make style

* add bad_words_tokens

* make style

* better naming

* make style

* fix typo
parent ae6834e0
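
For context, the new argument is used as in the sketch below, which mirrors the docstring example added in this commit (the model name and word list are illustrative, not part of the diff):

from transformers import AutoTokenizer, AutoModelWithLMHead

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelWithLMHead.from_pretrained('gpt2')
# one list of token ids per banned word or phrase
bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
input_ids = tokenizer.encode('My cute dog', return_tensors='pt')
# the banned sequences can no longer be produced during generation
outputs = model.generate(input_ids, max_length=50, do_sample=True, bad_words_ids=bad_words_ids)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))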
@@ -80,6 +80,7 @@ class PretrainedConfig(object):
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
# Fine-tuning task arguments
...
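
Because the new kwarg is stored on PretrainedConfig, a ban list can presumably also be set once on a model's config and then picked up as the default whenever generate is called without bad_words_ids; a sketch under that assumption (model name and ids illustrative):

from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
# set the default directly on the config object
model.config.bad_words_ids = [tokenizer.encode('stupid', add_prefix_space=True)]
input_ids = tokenizer.encode('My cute dog', return_tensors='pt')
# no bad_words_ids passed here: generate falls back to self.config.bad_words_ids
outputs = model.generate(input_ids, max_length=30, do_sample=True)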
@@ -467,6 +467,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_k=None,
top_p=None,
repetition_penalty=None,
bad_words_ids=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
@@ -532,6 +533,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
no_repeat_ngram_size: (`optional`) int
If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
bad_words_ids: (`optional`) list of lists of int
`bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
num_return_sequences: (`optional`) int
The number of independently computed returned sequences for each element in the batch. Default to 1.
@@ -582,6 +586,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained('gpt2')  # Download model and configuration from S3 and cache.
input_context = 'My cute dog'  # initial context for generation
bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated

"""

# We cannot generate if the model does not have a LM head
@@ -607,6 +617,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
@@ -641,6 +652,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a strictly positive integer."
assert (
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
if input_ids is None:
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
@@ -742,6 +756,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
@@ -766,6 +781,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
@@ -790,6 +806,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_p,
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
bos_token_id,
pad_token_id,
eos_token_id,
@@ -828,7 +845,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
- banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
+ banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
# create banned_tokens boolean mask
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
@@ -840,6 +857,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)

if bad_words_ids is not None:
# calculate a list of banned tokens according to bad words
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
banned_tokens_indices_mask.append(
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)

next_token_logits = set_tensor_by_indices_to_value(
next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)

# set eos token prob to zero if min_length is not reached
if eos_token_id is not None and cur_len < min_length:
# create eos_token_id boolean mask
@@ -936,6 +967,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
top_p,
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
bos_token_id,
pad_token_id,
decoder_start_token_id,
@@ -1012,7 +1044,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
num_batch_hypotheses = batch_size * num_beams
- banned_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)
+ banned_tokens = calc_banned_ngram_tokens(
+ input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
+ )
# create banned_tokens boolean mask
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
@@ -1024,6 +1058,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)

if bad_words_ids is not None:
# calculate a list of banned tokens according to bad words
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
banned_tokens_indices_mask.append(
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)

scores = set_tensor_by_indices_to_value(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)

assert shape_list(scores) == [batch_size * num_beams, vocab_size]

if do_sample:
@@ -1243,7 +1291,7 @@ def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)

- def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
+ def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
# Copied from fairseq for no_repeat_ngram in beam_search"""
if cur_len + 1 < no_repeat_ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
@@ -1266,6 +1314,42 @@ def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len)
return banned_tokens
def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
banned_tokens = []

def _tokens_match(prev_tokens, tokens):
if len(tokens) == 0:
# if the bad word is just one token, always ban it
return True
if len(tokens) > len(prev_input_ids):
# if bad word tokens are longer than prev input_ids they can't be equal
return False

if prev_tokens[-len(tokens) :] == tokens:
# if tokens match
return True
else:
return False

for prev_input_ids_slice in prev_input_ids:
banned_tokens_slice = []
for banned_token_seq in bad_words_ids:
assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
bad_words_ids
)

if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
# if tokens do not match, continue
continue

banned_tokens_slice.append(banned_token_seq[-1])

banned_tokens.append(banned_tokens_slice)

return banned_tokens
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
...
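
To make the new helper's behaviour concrete, here is a small self-contained sketch of the same matching rule on plain Python lists (tensor-to-list conversions omitted; the function and variable names here are illustrative, not part of the library):

def banned_last_tokens(prev_ids, bad_words_ids):
    # for each hypothesis, ban the last token of every bad-word sequence whose
    # prefix (all tokens but the last) matches the tail of the ids generated so far
    banned = []
    for hyp in prev_ids:
        hyp_banned = []
        for seq in bad_words_ids:
            prefix = seq[:-1]
            if prefix == [] or hyp[-len(prefix):] == prefix:
                hyp_banned.append(seq[-1])
        banned.append(hyp_banned)
    return banned

print(banned_last_tokens([[1, 2, 7], [1, 2, 3]], [[5], [7, 8]]))
# -> [[5, 8], [5]]: token 5 is always banned; 8 is banned only where prefix 7 was just generated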
@@ -667,6 +667,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
top_k=None,
top_p=None,
repetition_penalty=None,
bad_words_ids=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
@@ -731,6 +732,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
no_repeat_ngram_size: (`optional`) int
If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
bad_words_ids: (`optional`) list of lists of int
`bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
num_return_sequences: (`optional`) int
The number of independently computed returned sequences for each element in the batch. Default to 1.
@@ -782,6 +785,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('gpt2')  # Download model and configuration from S3 and cache.
input_context = 'My cute dog'  # initial context for generation
bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated

"""

# We cannot generate if the model does not have a LM head
@@ -807,6 +816,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
@@ -844,6 +854,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a strictly positive integer."
assert (
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

if input_ids is None:
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
@@ -964,6 +977,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
decoder_start_token_id=decoder_start_token_id,
@@ -988,6 +1002,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
decoder_start_token_id=decoder_start_token_id,
@@ -1011,6 +1026,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
top_p,
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
bos_token_id,
pad_token_id,
eos_token_id,
@@ -1045,7 +1061,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
- banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
+ banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
for batch_idx in range(batch_size):
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

if bad_words_ids is not None:
# calculate a list of banned tokens according to bad words
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
for batch_idx in range(batch_size):
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
@@ -1121,6 +1144,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
top_p,
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
bos_token_id,
pad_token_id,
eos_token_id,
@@ -1187,12 +1211,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
num_batch_hypotheses = batch_size * num_beams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
- banned_batch_tokens = calc_banned_tokens(
+ banned_batch_tokens = calc_banned_ngram_tokens(
input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
)
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")

if bad_words_ids is not None:
# calculate a list of banned tokens according to bad words
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
for i, banned_tokens in enumerate(banned_tokens):
scores[i, banned_tokens] = -float("inf")

assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
scores.shape, (batch_size * num_beams, vocab_size)
)
@@ -1397,7 +1428,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return past

- def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
+ def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
# Copied from fairseq for no_repeat_ngram in beam_search"""
if cur_len + 1 < no_repeat_ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
@@ -1420,6 +1451,42 @@ def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len)
return banned_tokens
def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
banned_tokens = []

def _tokens_match(prev_tokens, tokens):
if len(tokens) == 0:
# if the bad word is just one token, always ban it
return True
if len(tokens) > len(prev_input_ids):
# if bad word tokens are longer than prev input_ids they can't be equal
return False

if prev_tokens[-len(tokens) :] == tokens:
# if tokens match
return True
else:
return False

for prev_input_ids_slice in prev_input_ids:
banned_tokens_slice = []
for banned_token_seq in bad_words_ids:
assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
bad_words_ids
)

if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
# if tokens do not match, continue
continue

banned_tokens_slice.append(banned_token_seq[-1])

banned_tokens.append(banned_tokens_slice)

return banned_tokens
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
...
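
On the PyTorch side, banning works by writing -inf into the chosen logit positions before softmax, which drives their sampling probability to exactly zero. A minimal standalone illustration of that effect (toy numbers, not library code):

import torch
import torch.nn.functional as F

next_token_logits = torch.tensor([[1.0, 2.0, 0.5, 3.0]])  # one hypothesis, vocabulary of 4
banned_tokens = [[1, 3]]  # token ids banned for hypothesis 0

for batch_idx in range(next_token_logits.shape[0]):
    next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

probs = F.softmax(next_token_logits, dim=-1)
print(probs)  # banned positions come out as exactly 0.0, so sampling can never pick them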
@@ -641,14 +641,14 @@ class ModelTesterMixin:
with self.assertRaises(AssertionError):
model.generate(do_sample=True, max_length=5)

# batch_size = 1
- self._check_generated_tokens(model.generate(input_ids, do_sample=True))
+ self._check_generated_ids(model.generate(input_ids, do_sample=True))
# batch_size = 1, num_beams > 1
- self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_beams=3))
+ self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=3))
else:
# batch_size = 1
- self._check_generated_tokens(model.generate(do_sample=True, max_length=5))
+ self._check_generated_ids(model.generate(do_sample=True, max_length=5))
# batch_size = 1, num_beams > 1
- self._check_generated_tokens(model.generate(do_sample=True, max_length=5, num_beams=3))
+ self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=3))

with self.assertRaises(AssertionError):
# generating multiple sequences when greedy no beam generation
@@ -660,24 +660,52 @@ class ModelTesterMixin:
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)

# batch_size > 1, sample
- self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_return_sequences=3))
+ self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=3))
# batch_size > 1, greedy
- self._check_generated_tokens(model.generate(input_ids, do_sample=False))
+ self._check_generated_ids(model.generate(input_ids, do_sample=False))
# batch_size > 1, num_beams > 1, sample
- self._check_generated_tokens(
- model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,)
- )
+ self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,))
# batch_size > 1, num_beams > 1, greedy
- self._check_generated_tokens(
- model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
- )
+ self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3))

# check bad words tokens language generation
bad_words_ids = [
ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(-1).tolist(),
ids_tensor((2, 1), self.model_tester.vocab_size).squeeze(-1).tolist(),
]
# sampling
output_tokens = model.generate(
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=3
)
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))

# beam search
output_tokens = model.generate(
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=3, num_return_sequences=3
)
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))

- def _check_generated_tokens(self, output_ids):
+ def _check_generated_ids(self, output_ids):
for token_id in output_ids[0].tolist():
self.assertGreaterEqual(token_id, 0)
self.assertLess(token_id, self.model_tester.vocab_size)

def _check_match_tokens(self, generated_ids, bad_words_ids):
# for all bad word tokens
for bad_word_ids in bad_words_ids:
# for all slices in batch
for generated_ids_slice in generated_ids:
# for all word idx
for i in range(len(bad_word_ids), len(generated_ids_slice)):
# if tokens match
if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
return True
return False

global_rng = random.Random()
...
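
The new _check_match_tokens helper simply slides a window over the generated ids and reports whether any banned sequence occurs contiguously; a quick standalone check of that logic (plain Python, illustrative ids):

def check_match_tokens(generated_ids, bad_words_ids):
    # mirrors the test helper above: True if a banned sequence appears as a contiguous window
    for bad_word_ids in bad_words_ids:
        for generated_ids_slice in generated_ids:
            for i in range(len(bad_word_ids), len(generated_ids_slice)):
                if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
                    return True
    return False

print(check_match_tokens([[4, 7, 8, 2]], [[7, 8]]))  # True: 7, 8 occurs back to back
print(check_match_tokens([[4, 7, 2, 8]], [[7, 8]]))  # False: 7 and 8 are not adjacent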
@@ -427,14 +427,14 @@ class TFModelTesterMixin:
with self.assertRaises(AssertionError):
model.generate(do_sample=True, max_length=5)

# batch_size = 1
- self._check_generated_tokens(model.generate(input_ids, do_sample=True))
+ self._check_generated_ids(model.generate(input_ids, do_sample=True))
# batch_size = 1, num_beams > 1
- self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_beams=3))
+ self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=3))
else:
# batch_size = 1
- self._check_generated_tokens(model.generate(do_sample=True, max_length=5))
+ self._check_generated_ids(model.generate(do_sample=True, max_length=5))
# batch_size = 1, num_beams > 1
- self._check_generated_tokens(model.generate(do_sample=True, max_length=5, num_beams=3))
+ self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=3))

with self.assertRaises(AssertionError):
# generating multiple sequences when greedy no beam generation
@@ -446,24 +446,52 @@ class TFModelTesterMixin:
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)

# batch_size > 1, sample
- self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_return_sequences=3))
+ self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=3))
# batch_size > 1, greedy
- self._check_generated_tokens(model.generate(input_ids, do_sample=False))
+ self._check_generated_ids(model.generate(input_ids, do_sample=False))
# batch_size > 1, num_beams > 1, sample
- self._check_generated_tokens(
- model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,)
- )
+ self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,))
# batch_size > 1, num_beams > 1, greedy
- self._check_generated_tokens(
- model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
- )
+ self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3))

# check bad words tokens language generation
bad_words_ids = [
tf.squeeze(ids_tensor((1, 1), self.model_tester.vocab_size), -1).numpy().tolist(),
tf.squeeze(ids_tensor((2, 1), self.model_tester.vocab_size), -1).numpy().tolist(),
]
# sampling
output_tokens = model.generate(
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=3
)
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))

# beam search
output_tokens = model.generate(
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=3, num_return_sequences=3
)
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))

- def _check_generated_tokens(self, output_ids):
+ def _check_generated_ids(self, output_ids):
for token_id in output_ids[0].numpy().tolist():
self.assertGreaterEqual(token_id, 0)
self.assertLess(token_id, self.model_tester.vocab_size)

def _check_match_tokens(self, generated_ids, bad_words_ids):
# for all bad word tokens
for bad_word_ids in bad_words_ids:
# for all slices in batch
for generated_ids_slice in generated_ids:
# for all word idx
for i in range(len(bad_word_ids), len(generated_ids_slice)):
# if tokens match
if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
return True
return False

def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
...