Unverified Commit 676247fd authored by Rishab26, committed by GitHub

[DOCS] Add `NoRepeatNGramLogitsProcessor` Example for `LogitsProcessor` class (#25186)

* Add Description And Example to Docstring

* make style corrections

* make style

* Doc Style Consistent With HF

* Apply make style

* Modify Docstring

* Edit Type in Docstring

* Feedback Incorporated

* Edit Docstring

* make style

* Post Review Changes

* Review Feedback Incorporated

* Styling

* Formatting

* make style

* pep8
parent 5fe36970
@@ -577,10 +577,28 @@ class EtaLogitsWarper(LogitsWarper):
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
"""
Assume ngram_size=2 and prev_input_ids=tensor([[40, 2883, 2712, 4346]]). The output of generated ngrams look like
this {(40,): [2883], (2883,): [2712], (2712,): [4346]}.
Args:
ngram_size (`int`):
The number sequential tokens taken as a group which may only occur once before being banned.
prev_input_ids (`torch.Tensor`):
Generated token ids for the current hypothesis.
num_hypos (`int`):
The number of hypotheses for which n-grams need to be generated.
Returns:
generated_ngrams (`dict`):
Dictionary of generated ngrams.
"""
    # Initialize an empty list of dictionaries, one for each hypothesis (index) in the range of num_hypos
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        # Loop through each n-gram of size ngram_size in the list of tokens (gen_tokens)
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
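
As a concrete check of the docstring example, here is a standalone sketch of the same prefix bookkeeping (`get_ngrams_sketch` is a hypothetical re-implementation for illustration, not the library helper):

```py
import torch


def get_ngrams_sketch(ngram_size, prev_input_ids, num_hypos):
    # One dict per hypothesis, mapping each (ngram_size - 1)-token prefix to
    # the list of tokens that have followed it so far.
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prefix = tuple(ngram[:-1])
            generated_ngrams[idx][prefix] = generated_ngrams[idx].get(prefix, []) + [ngram[-1]]
    return generated_ngrams


print(get_ngrams_sketch(2, torch.tensor([[40, 2883, 2712, 4346]]), 1))
# [{(40,): [2883], (2883,): [2712], (2712,): [4346]}]
```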
@@ -588,6 +606,22 @@ def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
"""
Determines the banned tokens for the current hypothesis based on previously generated n-grams.
Args:
banned_ngrams (`dict`):
A dictionary containing previously generated n-grams for each hypothesis.
prev_input_ids (`torch.Tensor`):
Generated token ids for the current hypothesis.
ngram_size (`int`):
The number sequential tokens taken as a group which may only occur once before being banned.
cur_len (`int`):
The current length of the token sequences for which the n-grams are being checked.
Returns:
List of tokens that are banned.
"""
    # Before decoding the next token, prevent decoding of ngrams that have already appeared
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
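
The lookup reads off the last `ngram_size - 1` generated tokens and uses them as a key into the n-gram dictionary. A minimal sketch, reusing the dictionary from the `_get_ngrams` docstring example (token ids here are illustrative):

```py
import torch

# Dictionary of previously generated bigrams, as in the _get_ngrams example.
banned_ngrams = {(40,): [2883], (2883,): [2712], (2712,): [4346]}

prev_input_ids = torch.tensor([40, 2883, 2712, 40])
ngram_size, cur_len = 2, 4

# The last (ngram_size - 1) generated tokens form the lookup key.
start_idx = cur_len + 1 - ngram_size
ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())  # (40,)

# Token 2883 followed 40 before, so it is banned from following 40 again.
print(banned_ngrams.get(ngram_idx, []))  # [2883]
```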
@@ -601,9 +635,7 @@ def _calc_banned_ngram_tokens(
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
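
Putting the two helpers together: skip early when too few tokens have been generated for any n-gram to repeat, then collect per-hypothesis bans. A condensed, self-contained sketch of that computation (a re-implementation for illustration, with made-up token ids):

```py
import torch


def banned_ngram_tokens_sketch(ngram_size, prev_input_ids, num_hypos, cur_len):
    if cur_len + 1 < ngram_size:
        # Too few tokens generated for any n-gram to repeat yet.
        return [[] for _ in range(num_hypos)]
    banned = []
    for idx in range(num_hypos):
        tokens = prev_input_ids[idx].tolist()
        ngrams = {}
        for ngram in zip(*[tokens[i:] for i in range(ngram_size)]):
            ngrams.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
        key = tuple(tokens[cur_len + 1 - ngram_size : cur_len])
        banned.append(ngrams.get(key, []))
    return banned


# Hypothesis 0 ends in token 40, which was previously followed by 2883.
ids = torch.tensor([[40, 2883, 2712, 40], [7, 8, 9, 10]])
print(banned_ngram_tokens_sketch(2, ids, 2, 4))  # [[2883], []]
```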
@@ -613,12 +645,43 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    N-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
    sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). In text generation,
    avoiding repetitions of word sequences provides a more diverse output. This [`LogitsProcessor`] enforces no
    repetition of n-grams by setting the scores of banned tokens to negative infinity which eliminates those tokens
    from consideration when further processing the scores.
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

    <Tip>

    Use n-gram penalties with care. For instance, penalizing 2-grams (bigrams) in an article about the city of New
    York might lead to undesirable outcomes where the city's name appears only once in the entire text.
    [Reference](https://huggingface.co/blog/how-to-generate)

    </Tip>
    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.
    Examples:

    ```py
    >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    >>> inputs = tokenizer(["I enjoy playing football"], return_tensors="pt")

    >>> output = model.generate(**inputs, max_length=50)
    >>> print(tokenizer.decode(output[0], skip_special_tokens=True))
    I enjoy playing football on the weekends, but I'm not a big fan of the idea of playing in the middle of the night. I'm not a big fan of the idea of playing in the middle of the night. I'm not a big

    >>> # Now let's add ngram size using `no_repeat_ngram_size` in `model.generate`. This should stop the repetitions in the output.
    >>> output = model.generate(**inputs, max_length=50, no_repeat_ngram_size=2)
    >>> print(tokenizer.decode(output[0], skip_special_tokens=True))
    I enjoy playing football on the weekends, but I'm not a big fan of the idea of playing in the middle of a game. I think it's a bit of an overreaction to the fact that we're playing a team that's playing
    ```
""" """
def __init__(self, ngram_size: int): def __init__(self, ngram_size: int):
@@ -631,7 +694,6 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
        num_batch_hypotheses = scores.shape[0]
        cur_len = input_ids.shape[-1]
        banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")
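
Beyond the `no_repeat_ngram_size` generation argument shown in the docstring example, the same processor can be passed to `generate` explicitly through `LogitsProcessorList`; a sketch that should be equivalent in effect:

```py
from transformers import (
    AutoModelForCausalLM,
    GPT2Tokenizer,
    LogitsProcessorList,
    NoRepeatNGramLogitsProcessor,
)

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
inputs = tokenizer(["I enjoy playing football"], return_tensors="pt")

# Equivalent to no_repeat_ngram_size=2: ban any bigram from occurring twice.
processors = LogitsProcessorList([NoRepeatNGramLogitsProcessor(ngram_size=2)])
output = model.generate(**inputs, max_length=50, logits_processor=processors)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```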