Unverified Commit 45da7cec authored by Motoki Wu's avatar Motoki Wu Committed by GitHub
Browse files

Add custom stop token ids for generation (#20727)

* Add StopIdStoppingCriteria

* add a working test for stop id criteria

* add to global scope

* add stop_ids to generate

* add pipeline test

* use tokenizer encode in test

* add test to generation utils

* reformat

* fixup

* make-fix-copies

* rename to stop_token_id

* use stop_tokens instead

* add to text to text generation

* make fixup

* make repo-consistency

* Add support for list of ints for eos_token_id inside generation/utils.py

* Instead of having if elses, cast the eos_token_id into a List[int]

* Add List[int] support for logits_process.py

* add List[int] for beam_search.py

* add List[int] for forced_eos_token_id

* revert stop token id stopping criteria changes

* make fixup

* fix tests

* add eos_token_id to generation/utils.py and added tests test_utils.py

* add eos_token_id type hints and fix for pad tokens

* add comments

* remove some prints and remove forced false test

* fix

* put back test_stop_sequence_stopping_criteria

* remove unused import and make fixup

* add a none check

* update docstring

* add more docstring for list ints

* make fixup
parent cd918492
......@@ -16,7 +16,7 @@
import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -42,8 +42,8 @@ PROCESS_INPUTS_DOCSTRING = r"""
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Return:
`UserDict`: A dictionary composed of the fields as defined above:
......@@ -74,8 +74,8 @@ FINALIZE_INPUTS_DOCSTRING = r"""
The beam indices indicating to which beam the `final_beam_tokens` shall be added.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Return:
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
......@@ -212,7 +212,7 @@ class BeamSearchScorer(BeamScorer):
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1]
......@@ -234,6 +234,9 @@ class BeamSearchScorer(BeamScorer):
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
if self.num_beams < len(beam_hyp):
......@@ -253,7 +256,7 @@ class BeamSearchScorer(BeamScorer):
):
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
if is_beam_token_worse_than_top_num_beams:
......@@ -307,11 +310,14 @@ class BeamSearchScorer(BeamScorer):
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
......@@ -376,7 +382,8 @@ class BeamSearchScorer(BeamScorer):
indices[i, : len(best_idx)] = torch.tensor(best_idx)
if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id
# inserting only the first eos_token_id
decoded[i, sent_lengths[i]] = eos_token_id[0]
return UserDict(
{
......@@ -491,7 +498,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
next_indices: torch.LongTensor,
scores_for_all_vocab: torch.FloatTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
) -> Tuple[torch.Tensor]:
r"""
Args:
......@@ -512,8 +519,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
The scores of all tokens in the vocabulary for each of the beam hypotheses.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Return:
`UserDict`: A dictionary composed of the fields as defined above:
......@@ -549,6 +556,9 @@ class ConstrainedBeamSearchScorer(BeamScorer):
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
if self.num_beams < len(beam_hyp):
......@@ -568,7 +578,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
):
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
......@@ -773,10 +783,13 @@ class ConstrainedBeamSearchScorer(BeamScorer):
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
......@@ -840,7 +853,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id
# inserting only the first eos_token_id
decoded[i, sent_lengths[i]] = eos_token_id[0]
return UserDict(
{
......
......@@ -142,8 +142,9 @@ class GenerationConfig(PushToHubMixin):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
language token.
forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`):
The id of the token to force as the last generated token when `max_length` is reached.
forced_eos_token_id (`Union[int, List[int]]`, *optional*, defaults to `model.config.forced_eos_token_id`):
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
list to set multiple *end-of-sequence* tokens.
remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash.
Note that using `remove_invalid_values` can slow down generation.
......@@ -152,10 +153,10 @@ class GenerationConfig(PushToHubMixin):
generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where
penalty starts and `decay_factor` represents the factor of exponential decay
suppress_tokens (`List[int]`, *optional*):
A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set their
A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their
log probs to `-inf` so that they are not sampled.
begin_suppress_tokens (`List[int]`, *optional*):
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens` logit
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[List[int]]`, *optional*):
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
......@@ -183,8 +184,8 @@ class GenerationConfig(PushToHubMixin):
The id of the *padding* token.
bos_token_id (`int`, *optional*):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
> Generation parameters exclusive to encoder-decoder models
......
......@@ -15,7 +15,7 @@
import inspect
import math
from typing import Callable, Iterable, List, Optional, Tuple
from typing import Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -100,16 +100,18 @@ class MinLengthLogitsProcessor(LogitsProcessor):
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`int`):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
"""
def __init__(self, min_length: int, eos_token_id: int):
def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
self.min_length = min_length
self.eos_token_id = eos_token_id
......@@ -117,7 +119,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len < self.min_length:
scores[:, self.eos_token_id] = -float("inf")
for i in self.eos_token_id:
scores[:, i] = -float("inf")
return scores
......@@ -431,11 +434,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
add_special_tokens=False).input_ids`.
eos_token_id (`int`):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
"""
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
......@@ -449,7 +452,14 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
if eos_token_id is None:
eos_token_id = []
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
bad_words_ids = list(
filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
)
self.bad_words_id_length_1 = []
self.bad_words_id_length_greater_than_1 = []
for word in bad_words_ids:
......@@ -664,20 +674,24 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
eos_token_id (`int`):
The id of the token to force as the last generated token when `max_length` is reached.
eos_token_id (`Union[int, List[int]]`):
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
list to set multiple *end-of-sequence* tokens.
"""
def __init__(self, max_length: int, eos_token_id: int):
def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
self.max_length = max_length
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len == self.max_length - 1:
num_tokens = scores.shape[1]
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
scores[:, self.eos_token_id] = 0
scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = -float("inf")
for i in self.eos_token_id:
scores[:, i] = 0
return scores
......@@ -707,23 +721,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
starts and `decay_factor` represents the factor of exponential decay
eos_token_id (`int`):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
input_ids_seq_length (`int`):
The length of the input sequence.
"""
def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int):
def __init__(
self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int
):
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
self.regulation_factor = exponential_decay_length_penalty[1]
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len > self.regulation_start:
scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
self.regulation_factor, cur_len - self.regulation_start
)
for i in self.eos_token_id:
scores[:, i] = scores[:, i] * pow(self.regulation_factor, cur_len - self.regulation_start)
return scores
......
......@@ -575,11 +575,13 @@ class GenerationMixin:
self,
inputs: torch.Tensor,
pad_token_id: Optional[int],
eos_token_id: Optional[int],
eos_token_id: Optional[Union[int, List[int]]],
) -> torch.LongTensor:
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
# Check if input is input_ids and padded -> only then is attention_mask defined
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
......@@ -902,7 +904,7 @@ class GenerationMixin:
sequences: torch.Tensor,
scores: Tuple[torch.Tensor],
beam_indices: torch.Tensor,
eos_token_id: int = None,
eos_token_id: Union[int, List[int]] = None,
):
"""compute the transition probabilities of sequences given generation
scores and beam indices"""
......@@ -1165,10 +1167,11 @@ class GenerationMixin:
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
logger.warning(
f"Setting `pad_token_id` to `eos_token_id`:{generation_config.eos_token_id} for open-end generation."
)
generation_config.pad_token_id = generation_config.eos_token_id
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
# 3. Define model inputs
# inputs_tensor has to be defined
......@@ -1624,7 +1627,7 @@ class GenerationMixin:
logits_warper: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
......@@ -1708,6 +1711,8 @@ class GenerationMixin:
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
......@@ -1930,7 +1935,7 @@ class GenerationMixin:
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
......@@ -1967,7 +1972,7 @@ class GenerationMixin:
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
......@@ -2067,6 +2072,8 @@ class GenerationMixin:
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
......@@ -2162,7 +2169,7 @@ class GenerationMixin:
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
......@@ -2200,7 +2207,7 @@ class GenerationMixin:
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
......@@ -2320,6 +2327,8 @@ class GenerationMixin:
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
......@@ -2418,7 +2427,7 @@ class GenerationMixin:
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
......@@ -2456,7 +2465,7 @@ class GenerationMixin:
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
......@@ -2576,6 +2585,8 @@ class GenerationMixin:
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
......@@ -2770,7 +2781,7 @@ class GenerationMixin:
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
......@@ -2900,6 +2911,8 @@ class GenerationMixin:
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
......@@ -3089,7 +3102,7 @@ class GenerationMixin:
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
......@@ -3213,6 +3226,8 @@ class GenerationMixin:
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
......@@ -3455,7 +3470,7 @@ class GenerationMixin:
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
......@@ -3586,6 +3601,8 @@ class GenerationMixin:
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
......
......@@ -17,7 +17,7 @@
import inspect
import unittest
from transformers import is_torch_available
from transformers import is_torch_available, pipeline
from transformers.testing_utils import require_torch, slow, torch_device
from ..test_modeling_common import floats_tensor, ids_tensor
......@@ -39,7 +39,6 @@ if is_torch_available():
SpeechEncoderDecoderModel,
T5ForConditionalGeneration,
VisionEncoderDecoderModel,
pipeline,
top_k_top_p_filtering,
)
from transformers.generation import (
......@@ -91,8 +90,9 @@ class GenerationTesterMixin:
max_length = input_ids.shape[-1] + 3
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
if isinstance(config.eos_token_id, int):
config.eos_token_id = [config.eos_token_id]
config.pad_token_id = config.eos_token_id[0]
# TransfoXL has no attention mask
if "transfoxl" in config.__class__.__name__.lower():
attention_mask = None
......@@ -3025,3 +3025,100 @@ class GenerationIntegrationTests(unittest.TestCase):
# However, valid model_kwargs are accepted
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
model.generate(input_ids, **valid_model_kwargs)
def test_eos_token_id_int_and_list_greedy_search(self):
generation_kwargs = {
"do_sample": False,
"num_beams": 1,
}
expectation = 13
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="pt")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
torch.manual_seed(0)
eos_token_id = 873
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [873]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
def test_eos_token_id_int_and_list_contrastive_search(self):
generation_kwargs = {
"do_sample": False,
"num_beams": 1,
"penalty_alpha": 0.6,
"top_k": 4,
}
expectation = 17
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="pt")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
torch.manual_seed(0)
eos_token_id = 225
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [225]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
generation_kwargs = {
"do_sample": True,
"num_beams": 1,
"top_p": 0.7,
"top_k": 10,
"temperature": 0.7,
}
expectation = 15
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="pt")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
torch.manual_seed(0)
eos_token_id = 846
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [846]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
def test_eos_token_id_int_and_list_beam_search(self):
generation_kwargs = {
"do_sample": False,
"num_beams": 3,
}
expectation = 13
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="pt")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
torch.manual_seed(0)
eos_token_id = 873
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [873]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment