"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "eff274d629c95fca459969b530b4ad0da5563918"
Unverified commit 77bf3fe7, authored by Patrick von Platen, committed by GitHub

[Generate] Add safe mode logits processor to remove nans and infs if necessary (#10769)

* push

* finish

* finish

* make fix copies

* change name
parent 9f8fa4e9
@@ -151,6 +151,16 @@ generation.
 .. autoclass:: transformers.HammingDiversityLogitsProcessor
     :members: __call__

+.. autoclass:: transformers.ForcedBOSTokenLogitsProcessor
+    :members: __call__
+
+.. autoclass:: transformers.ForcedEOSTokenLogitsProcessor
+    :members: __call__
+
+.. autoclass:: transformers.InfNanRemoveLogitsProcessor
+    :members: __call__
+
 StoppingCriteria
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
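The three processors added to the docs share the ``__call__(input_ids, scores)`` interface of all logits processors. As a quick orientation, here is a minimal sketch (toy tensors, not part of this diff) of what ``ForcedEOSTokenLogitsProcessor`` does on the step before ``max_length`` is reached:

```python
import torch
from transformers import ForcedEOSTokenLogitsProcessor

eos_token_id = 2
processor = ForcedEOSTokenLogitsProcessor(max_length=5, eos_token_id=eos_token_id)

input_ids = torch.tensor([[0, 1, 3, 4]])  # cur_len == max_length - 1 triggers the forcing
scores = torch.zeros(1, 10)               # uniform logits over a toy vocab of size 10

scores = processor(input_ids, scores)
print(scores[0, eos_token_id])  # tensor(0.) -- every other id is now -inf
```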
@@ -369,7 +369,10 @@ if is_torch_available():
     ]
     _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"]
     _import_structure["generation_logits_process"] = [
+        "ForcedBOSTokenLogitsProcessor",
+        "ForcedEOSTokenLogitsProcessor",
         "HammingDiversityLogitsProcessor",
+        "InfNanRemoveLogitsProcessor",
         "LogitsProcessor",
         "LogitsProcessorList",
         "LogitsWarper",
@@ -1560,7 +1563,10 @@ if TYPE_CHECKING:
     )
     from .generation_beam_search import BeamScorer, BeamSearchScorer
     from .generation_logits_process import (
+        ForcedBOSTokenLogitsProcessor,
+        ForcedEOSTokenLogitsProcessor,
         HammingDiversityLogitsProcessor,
+        InfNanRemoveLogitsProcessor,
         LogitsProcessor,
         LogitsProcessorList,
         LogitsWarper,
...
@@ -134,6 +134,9 @@ class PretrainedConfig(object):
           <../model_doc/mbart>` where the first generated token needs to be the target language token.
         - **forced_eos_token_id** (:obj:`int`, `optional`) -- The id of the token to force as the last generated token
           when :obj:`max_length` is reached.
+        - **remove_invalid_values** (:obj:`bool`, `optional`) -- Whether to remove possible `nan` and `inf` outputs of
+          the model to prevent the generation method from crashing. Note that using ``remove_invalid_values`` can
+          slow down generation.
     Parameters for fine-tuning tasks

@@ -219,6 +222,7 @@ class PretrainedConfig(object):
         self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
         self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
         self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
+        self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)

         # Fine-tuning task arguments
         self.architectures = kwargs.pop("architectures", None)
...
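Since the flag lands on ``PretrainedConfig`` with a ``False`` default, it can also be switched on once per model instead of per call; ``generate()`` falls back to the config value whenever the keyword argument is left as ``None``. A minimal sketch ("gpt2" is just an example checkpoint, not taken from this diff):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# opt in once; every later generate() call inherits the setting
model.config.remove_invalid_values = True

input_ids = tokenizer("Hello, my dog is", return_tensors="pt").input_ids
output = model.generate(input_ids, max_length=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```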
@@ -566,3 +566,20 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
             scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
             scores[:, self.eos_token_id] = 0
         return scores
+
+
+class InfNanRemoveLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`~transformers.LogitsProcessor` that removes all :obj:`nan` and :obj:`inf` values to avoid the generation
+    method failing. Note that this processor should only be used if necessary, since it can slow down generation.
+    """
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        # set all nan values to 0.0
+        scores[scores != scores] = 0.0
+
+        # set all +inf values to the max possible value
+        scores[scores == float("inf")] = torch.finfo(scores.dtype).max
+
+        return scores
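The processor can be exercised in isolation as well; this sketch mirrors the unit test added further down in this diff:

```python
import torch
from transformers import InfNanRemoveLogitsProcessor

scores = torch.tensor(
    [[0.0, 0.7, 0.8, float("nan")], [0.1, float("inf"), 0.3, float("-inf")]]
)
input_ids = torch.randint(0, 20, (2, 4))  # ignored by this particular processor

scores = InfNanRemoveLogitsProcessor()(input_ids, scores)
print(scores)
# nan -> 0.0 and +inf -> finfo.max; -inf is kept, since it is a valid
# "never sample this token" logit
```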
@@ -27,6 +27,7 @@ from .generation_logits_process import (
     ForcedBOSTokenLogitsProcessor,
     ForcedEOSTokenLogitsProcessor,
     HammingDiversityLogitsProcessor,
+    InfNanRemoveLogitsProcessor,
     LogitsProcessorList,
     MinLengthLogitsProcessor,
     NoBadWordsLogitsProcessor,
@@ -581,6 +582,7 @@ class GenerationMixin:
         num_beams: int,
         num_beam_groups: int,
         diversity_penalty: float,
+        remove_invalid_values: bool,
     ) -> LogitsProcessorList:
         """
         This method returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
@@ -607,6 +609,9 @@ class GenerationMixin:
         forced_eos_token_id = (
             forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
         )
+        remove_invalid_values = (
+            remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
+        )

         # instantiate processors list
         processors = LogitsProcessorList()
@@ -639,6 +644,8 @@ class GenerationMixin:
             processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
         if forced_eos_token_id is not None:
             processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
+        if remove_invalid_values is True:
+            processors.append(InfNanRemoveLogitsProcessor())
         return processors
     def _get_stopping_criteria(
@@ -687,6 +694,7 @@ class GenerationMixin:
         return_dict_in_generate: Optional[bool] = None,
         forced_bos_token_id: Optional[int] = None,
         forced_eos_token_id: Optional[int] = None,
+        remove_invalid_values: Optional[bool] = None,
         **model_kwargs,
     ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
         r"""
@@ -789,6 +797,9 @@ class GenerationMixin:
                 needs to be the target language token.
             forced_eos_token_id (:obj:`int`, `optional`):
                 The id of the token to force as the last generated token when :obj:`max_length` is reached.
+            remove_invalid_values (:obj:`bool`, `optional`):
+                Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method
+                from crashing. Note that using ``remove_invalid_values`` can slow down generation.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
@@ -965,6 +976,7 @@ class GenerationMixin:
             num_beams=num_beams,
             num_beam_groups=num_beam_groups,
             diversity_penalty=diversity_penalty,
+            remove_invalid_values=remove_invalid_values,
         )

         stopping_criteria = self._get_stopping_criteria(
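End to end, the new keyword simply flows from ``generate()`` into the processor list built above. A sketch of the explicit per-call opt-in (sampling is where invalid values actually bite, as the next two hunks show):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

# appends InfNanRemoveLogitsProcessor for this call only, sanitizing the
# logits at every decoding step before they reach torch.multinomial
output = model.generate(input_ids, do_sample=True, max_length=30, remove_invalid_values=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```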
@@ -1511,6 +1523,7 @@ class GenerationMixin:
             # sample
             probs = F.softmax(next_token_scores, dim=-1)
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

             # add code that transforms next_tokens to tokens_to_add
@@ -2026,6 +2039,7 @@ class GenerationMixin:
             next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

             probs = F.softmax(next_token_scores, dim=-1)

             next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
             next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
...
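The two sampling hunks above are the crash sites this PR guards: ``torch.multinomial`` rejects probability tensors containing ``nan`` or ``inf``, and a single bad logit poisons the whole softmax. A small illustration of the failure mode (exact error text varies across torch versions and devices):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, float("nan"), 2.0]])
probs = F.softmax(logits, dim=-1)  # nan propagates: every entry becomes nan

try:
    torch.multinomial(probs, num_samples=1)
except RuntimeError as err:
    print(err)  # e.g. "probability tensor contains either `inf`, `nan` or element < 0"

# sanitizing the logits first, as InfNanRemoveLogitsProcessor does, fixes the call
logits[logits != logits] = 0.0
print(torch.multinomial(F.softmax(logits, dim=-1), num_samples=1))
```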
@@ -1316,6 +1316,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
         forced_bos_token_id: Optional[int] = None,
         forced_eos_token_id: Optional[int] = None,
+        remove_invalid_values: Optional[bool] = None,
         **model_kwargs
     ):
         """
@@ -1412,6 +1413,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 needs to be the target language token.
             forced_eos_token_id (:obj:`int`, `optional`):
                 The id of the token to force as the last generated token when :obj:`max_length` is reached.
+            remove_invalid_values (:obj:`bool`, `optional`):
+                Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method
+                from crashing. Note that using ``remove_invalid_values`` can slow down generation.
        Return:
            :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated

@@ -1435,6 +1439,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
             if decoder_start_token_id is not None
             else self.config.generator.decoder_start_token_id
         )
+        remove_invalid_values = (
+            remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
+        )
# retrieve docs
        if self.retriever is not None and context_input_ids is None:

@@ -1515,6 +1522,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             num_beams=num_beams,
             num_beam_groups=num_beam_groups,
             diversity_penalty=diversity_penalty,
+            remove_invalid_values=remove_invalid_values,
         )

         if num_beams == 1:
...
@@ -123,11 +123,26 @@ class BeamSearchScorer:
         requires_pytorch(self)


+class ForcedBOSTokenLogitsProcessor:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
+class ForcedEOSTokenLogitsProcessor:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 class HammingDiversityLogitsProcessor:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)


+class InfNanRemoveLogitsProcessor:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 class LogitsProcessor:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)
...
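For context, these stubs follow the repo's optional-backend pattern: ``from transformers import InfNanRemoveLogitsProcessor`` always succeeds, but instantiating a torch-backed class in a torch-less environment fails with a readable error. A self-contained sketch of the idea (the ``requires_pytorch`` helper below is a hypothetical stand-in; the real one lives elsewhere in the library and its exact message differs):

```python
def requires_pytorch(obj):
    # hypothetical stand-in for the library's helper
    name = obj.__name__ if isinstance(obj, type) else obj.__class__.__name__
    raise ImportError(f"{name} requires the PyTorch library, but it was not found in your environment.")


class InfNanRemoveLogitsProcessor:
    # importable without torch, unusable until torch is installed
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


try:
    InfNanRemoveLogitsProcessor()
except ImportError as err:
    print(err)
```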
@@ -31,6 +31,7 @@ if is_torch_available():
        ForcedBOSTokenLogitsProcessor,
        ForcedEOSTokenLogitsProcessor,
        HammingDiversityLogitsProcessor,
+       InfNanRemoveLogitsProcessor,
        LogitsProcessorList,
        MinLengthLogitsProcessor,
        NoBadWordsLogitsProcessor,

@@ -436,3 +437,24 @@ class LogitsProcessorTest(unittest.TestCase):
         scores = self._get_uniform_logits(batch_size, vocab_size)
         scores = logits_processor(input_ids, scores)
         self.assertFalse(torch.isinf(scores).any())
+
+    def test_remove_nan_inf_logits_processor(self):
+        scores = torch.tensor(
+            [[0.0, 0.7, 0.8, float("nan")], [0.1, float("inf"), 0.3, float("-inf")]], device=torch_device
+        )
+        input_ids = ids_tensor((2, 4), vocab_size=20)
+
+        logits_processor = InfNanRemoveLogitsProcessor()
+
+        scores = logits_processor(input_ids, scores)
+
+        self.assertTrue(
+            torch.allclose(
+                scores,
+                torch.tensor(
+                    [[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
+                    device=torch_device,
+                ),
+                atol=1e-6,
+            )
+        )
@@ -29,6 +29,7 @@ if is_torch_available():
        ForcedBOSTokenLogitsProcessor,
        ForcedEOSTokenLogitsProcessor,
        HammingDiversityLogitsProcessor,
+       InfNanRemoveLogitsProcessor,
        LogitsProcessorList,
        MinLengthLogitsProcessor,
        NoBadWordsLogitsProcessor,
@@ -229,6 +230,7 @@ class GenerationTesterMixin:
             output_hidden_states=output_hidden_states,
             output_scores=output_scores,
             return_dict_in_generate=return_dict_in_generate,
+            remove_invalid_values=True,
             **logits_process_kwargs,
         )
@@ -284,6 +286,7 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
+            remove_invalid_values=True,
             **logits_warper_kwargs,
             **process_kwargs,
         )
@@ -305,6 +308,10 @@ class GenerationTesterMixin:
         attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
         input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)

+        # prevent flaky generation test failures
+        logits_processor.append(InfNanRemoveLogitsProcessor())
+
         with torch.no_grad():
             output_sample = model.sample(
                 input_ids_clone,
@@ -344,6 +351,7 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
+            remove_invalid_values=True,
             **beam_kwargs,
             **logits_process_kwargs,
         )
@@ -406,6 +414,7 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
+            remove_invalid_values=True,
             **beam_kwargs,
             **logits_warper_kwargs,
         )
@@ -424,6 +433,10 @@ class GenerationTesterMixin:
         else:
             attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)

+        # prevent flaky generation test failures
+        logits_processor = LogitsProcessorList()
+        logits_processor.append(InfNanRemoveLogitsProcessor())
+
         torch.manual_seed(0)
         with torch.no_grad():
             output_beam_sample = model.beam_sample(
@@ -432,6 +445,7 @@ class GenerationTesterMixin:
             max_length=max_length,
             attention_mask=attention_mask,
             logits_warper=logits_warper,
+            logits_processor=logits_processor,
             output_scores=output_scores,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -465,6 +479,7 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
+            remove_invalid_values=True,
             **beam_kwargs,
             **logits_process_kwargs,
         )
@@ -936,6 +951,7 @@ class GenerationTesterMixin:
         output_ids_generate = model.generate(
             do_sample=False,
             max_length=max_length,
+            remove_invalid_values=True,
         )
         self.assertIsNotNone(output_ids_generate)
@@ -1309,7 +1325,12 @@ class GenerationIntegrationTests(unittest.TestCase):
         input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

         outputs = bart_model.generate(
-            input_ids, num_beams=4, num_return_sequences=2, num_beam_groups=4, diversity_penalty=2.0
+            input_ids,
+            num_beams=4,
+            num_return_sequences=2,
+            num_beam_groups=4,
+            diversity_penalty=2.0,
+            remove_invalid_values=True,
         )
         generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
@@ -1359,6 +1380,7 @@ class GenerationIntegrationTests(unittest.TestCase):
             decoder_start_token_id=bart_model.config.decoder_start_token_id,
             bos_token_id=bart_model.config.bos_token_id,
         )
+        with torch.no_grad():
             bart_model.sample(
                 input_ids,
                 max_length=max_length,
@@ -1463,6 +1485,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         # Sample
         with self.assertWarns(UserWarning):
+            with torch.no_grad():
                 bart_model.sample(
                     input_ids,
                     max_length=max_length,
@@ -1480,6 +1503,7 @@ class GenerationIntegrationTests(unittest.TestCase):
             device=torch_device,
         )
         with self.assertWarns(UserWarning):
+            with torch.no_grad():
                 bart_model.beam_search(
                     input_ids,
                     num_beams=num_beams,
...