"utils/tests_fetcher.py" did not exist on "11edecd75379657c1929615fa729c8f3c07dcafc"
Unverified Commit 77bf3fe7 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Generate] Add save mode logits processor to remove nans and infs if necessary (#10769)

* push

* finish

* finish

* make fix copies

* change name
parent 9f8fa4e9
......@@ -151,6 +151,16 @@ generation.
.. autoclass:: transformers.HammingDiversityLogitsProcessor
:members: __call__
.. autoclass:: transformers.ForcedBOSTokenLogitsProcessor
:members: __call__
.. autoclass:: transformers.ForcedEOSTokenLogitsProcessor
:members: __call__
.. autoclass:: transformers.InfNanRemoveLogitsProcessor
:members: __call__
StoppingCriteria
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -369,7 +369,10 @@ if is_torch_available():
]
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"]
_import_structure["generation_logits_process"] = [
"ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor",
"HammingDiversityLogitsProcessor",
"InfNanRemoveLogitsProcessor",
"LogitsProcessor",
"LogitsProcessorList",
"LogitsWarper",
......@@ -1560,7 +1563,10 @@ if TYPE_CHECKING:
)
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitsProcessor,
LogitsProcessorList,
LogitsWarper,
......
......@@ -134,6 +134,9 @@ class PretrainedConfig(object):
<../model_doc/mbart>` where the first generated token needs to be the target language token.
- **forced_eos_token_id** (:obj:`int`, `optional`) -- The id of the token to force as the last generated token
when :obj:`max_length` is reached.
- **remove_invalid_values** (:obj:`bool`, `optional`) -- Whether to remove possible `nan` and `inf` outputs of
the model to prevent the generation method to crash. Note that using ``remove_invalid_values`` can slow down
generation.
Parameters for fine-tuning tasks
......@@ -219,6 +222,7 @@ class PretrainedConfig(object):
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
# Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None)
......
......@@ -566,3 +566,20 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
scores[:, self.eos_token_id] = 0
return scores
class InfNanRemoveLogitsProcessor(LogitsProcessor):
r"""
:class:`~transformers.LogitsProcessor` that removes all :obj:`nan` and :obj:`inf` values to avoid the generation
method to fail. Note that using the logits processor should only be used if necessary since it can slow down the
generation method. :obj:`max_length` is reached.
"""
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# set all nan values to 0.0
scores[scores != scores] = 0.0
# set all inf values to max possible value
scores[scores == float("inf")] = torch.finfo(scores.dtype).max
return scores
......@@ -27,6 +27,7 @@ from .generation_logits_process import (
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
......@@ -581,6 +582,7 @@ class GenerationMixin:
num_beams: int,
num_beam_groups: int,
diversity_penalty: float,
remove_invalid_values: bool,
) -> LogitsProcessorList:
"""
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
......@@ -607,6 +609,9 @@ class GenerationMixin:
forced_eos_token_id = (
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
)
remove_invalid_values = (
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
)
# instantiate processors list
processors = LogitsProcessorList()
......@@ -639,6 +644,8 @@ class GenerationMixin:
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
if forced_eos_token_id is not None:
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
if remove_invalid_values is True:
processors.append(InfNanRemoveLogitsProcessor())
return processors
def _get_stopping_criteria(
......@@ -687,6 +694,7 @@ class GenerationMixin:
return_dict_in_generate: Optional[bool] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
remove_invalid_values: Optional[bool] = None,
**model_kwargs,
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
r"""
......@@ -789,6 +797,9 @@ class GenerationMixin:
needs to be the target language token.
forced_eos_token_id (:obj:`int`, `optional`):
The id of the token to force as the last generated token when :obj:`max_length` is reached.
remove_invalid_values (:obj:`bool`, `optional`):
Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method to
crash. Note that using ``remove_invalid_values`` can slow down generation.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
......@@ -965,6 +976,7 @@ class GenerationMixin:
num_beams=num_beams,
num_beam_groups=num_beam_groups,
diversity_penalty=diversity_penalty,
remove_invalid_values=remove_invalid_values,
)
stopping_criteria = self._get_stopping_criteria(
......@@ -1511,6 +1523,7 @@ class GenerationMixin:
# sample
probs = F.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# add code that transfomers next_tokens to tokens_to_add
......@@ -2026,6 +2039,7 @@ class GenerationMixin:
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
probs = F.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
......
......@@ -1316,6 +1316,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
remove_invalid_values: Optional[bool] = None,
**model_kwargs
):
"""
......@@ -1412,6 +1413,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
needs to be the target language token.
forced_eos_token_id (:obj:`int`, `optional`):
The id of the token to force as the last generated token when :obj:`max_length` is reached.
remove_invalid_values (:obj:`bool`, `optional`):
Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method to
crash. Note that using ``remove_invalid_values`` can slow down generation.
Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
......@@ -1435,6 +1439,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
if decoder_start_token_id is not None
else self.config.generator.decoder_start_token_id
)
remove_invalid_values = (
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
)
# retrieve docs
if self.retriever is not None and context_input_ids is None:
......@@ -1515,6 +1522,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
num_beams=num_beams,
num_beam_groups=num_beam_groups,
diversity_penalty=diversity_penalty,
remove_invalid_values=remove_invalid_values,
)
if num_beams == 1:
......
......@@ -123,11 +123,26 @@ class BeamSearchScorer:
requires_pytorch(self)
class ForcedBOSTokenLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class ForcedEOSTokenLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class HammingDiversityLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class InfNanRemoveLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class LogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
......
......@@ -31,6 +31,7 @@ if is_torch_available():
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
......@@ -436,3 +437,24 @@ class LogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores).any())
def test_remove_nan_inf_logits_processor(self):
scores = torch.tensor(
[[0.0, 0.7, 0.8, float("nan")], [0.1, float("inf"), 0.3, float("-inf")]], device=torch_device
)
input_ids = ids_tensor((2, 4), vocab_size=20)
logits_processor = InfNanRemoveLogitsProcessor()
scores = logits_processor(input_ids, scores)
self.assertTrue(
torch.allclose(
scores,
torch.tensor(
[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
device=torch_device,
),
atol=1e-6,
)
)
......@@ -29,6 +29,7 @@ if is_torch_available():
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
......@@ -229,6 +230,7 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**logits_process_kwargs,
)
......@@ -284,6 +286,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**logits_warper_kwargs,
**process_kwargs,
)
......@@ -305,6 +308,10 @@ class GenerationTesterMixin:
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
# prevent flaky generation test failures
logits_processor.append(InfNanRemoveLogitsProcessor())
with torch.no_grad():
with torch.no_grad():
output_sample = model.sample(
input_ids_clone,
......@@ -344,6 +351,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**beam_kwargs,
**logits_process_kwargs,
)
......@@ -406,6 +414,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**beam_kwargs,
**logits_warper_kwargs,
)
......@@ -424,6 +433,10 @@ class GenerationTesterMixin:
else:
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
# prevent flaky generation test failures
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
torch.manual_seed(0)
with torch.no_grad():
output_beam_sample = model.beam_sample(
......@@ -432,6 +445,7 @@ class GenerationTesterMixin:
max_length=max_length,
attention_mask=attention_mask,
logits_warper=logits_warper,
logits_processor=logits_processor,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......@@ -465,6 +479,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**beam_kwargs,
**logits_process_kwargs,
)
......@@ -936,6 +951,7 @@ class GenerationTesterMixin:
output_ids_generate = model.generate(
do_sample=False,
max_length=max_length,
remove_invalid_values=True,
)
self.assertIsNotNone(output_ids_generate)
......@@ -1309,7 +1325,12 @@ class GenerationIntegrationTests(unittest.TestCase):
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = bart_model.generate(
input_ids, num_beams=4, num_return_sequences=2, num_beam_groups=4, diversity_penalty=2.0
input_ids,
num_beams=4,
num_return_sequences=2,
num_beam_groups=4,
diversity_penalty=2.0,
remove_invalid_values=True,
)
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
......@@ -1359,6 +1380,7 @@ class GenerationIntegrationTests(unittest.TestCase):
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
with torch.no_grad():
bart_model.sample(
input_ids,
max_length=max_length,
......@@ -1463,6 +1485,7 @@ class GenerationIntegrationTests(unittest.TestCase):
# Sample
with self.assertWarns(UserWarning):
with torch.no_grad():
bart_model.sample(
input_ids,
max_length=max_length,
......@@ -1480,6 +1503,7 @@ class GenerationIntegrationTests(unittest.TestCase):
device=torch_device,
)
with self.assertWarns(UserWarning):
with torch.no_grad():
bart_model.beam_search(
input_ids,
num_beams=num_beams,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment