Unverified commit b8ac4d03, authored by Raushan Turganbay, committed by GitHub

Fix generation doctests (#30263)

* fix doctest

* fix torch doctest

* make CI happy

* raise error

* make fixup
parent 2ecefc39
```diff
@@ -19,12 +19,12 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 import torch
 
 from ..cache_utils import DynamicCache
+from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
 
 
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
     from .configuration_utils import GenerationConfig
-    from .logits_process import LogitsProcessorList
 
 
 class CandidateGenerator:
@@ -94,9 +94,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
         input_ids: torch.LongTensor,
         assistant_model: "PreTrainedModel",
         generation_config: "GenerationConfig",
-        logits_processor: "LogitsProcessorList",
         model_kwargs: Dict,
         inputs_tensor: Optional[torch.Tensor] = None,
+        logits_processor: "LogitsProcessorList" = None,
     ):
         # Make sure all data at the same device as assistant model
         device = assistant_model.device
@@ -145,15 +145,22 @@ class AssistedCandidateGenerator(CandidateGenerator):
             self.input_ids_key = "input_ids"
 
         # Prepare generation-related options.
-        self.logits_processor = logits_processor
+        self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         self.generation_config = copy.deepcopy(generation_config)
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True
 
         # avoid unnecessary warnings that min_length is larger than max_new_tokens
+        # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
         self.main_model_min_length = self.generation_config.min_length
         self.generation_config.min_length = 0
         self.generation_config.min_new_tokens = None
+        for processor in self.logits_processor:
+            if type(processor) == MinLengthLogitsProcessor:
+                raise ValueError(
+                    "Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
+                    "Please pass in `min_length` into `.generate()` instead"
+                )
 
     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
```
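Taken together, these hunks change the contract of `AssistedCandidateGenerator`: `logits_processor` is now optional and falls back to an empty `LogitsProcessorList`, while a user-supplied `MinLengthLogitsProcessor` is rejected at construction time. A minimal sketch of the new behavior (checkpoints chosen for illustration; any compatible main/assistant pair works):

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)
from transformers.generation.candidate_generator import AssistedCandidateGenerator

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
input_ids = tokenizer("It might be possible to", return_tensors="pt").input_ids

# `logits_processor` may now be omitted; it defaults to an empty LogitsProcessorList.
candidate_generator = AssistedCandidateGenerator(
    input_ids=input_ids,
    assistant_model=assistant_model,
    generation_config=model.generation_config,
    model_kwargs={},
)

# Passing a MinLengthLogitsProcessor now raises; set `min_length` on the
# generation config (or on the `.generate()` call) instead.
try:
    AssistedCandidateGenerator(
        input_ids=input_ids,
        assistant_model=assistant_model,
        generation_config=model.generation_config,
        model_kwargs={},
        logits_processor=LogitsProcessorList(
            [MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id)]
        ),
    )
except ValueError as err:
    print(err)
```

The guard complements the lines just above it that stash `main_model_min_length` and zero the assistant's `min_length`: a hand-built `MinLengthLogitsProcessor` would bypass that bookkeeping.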
```diff
@@ -528,9 +528,9 @@ class TFGenerationMixin:
         >>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
         ...     # | token | token string | logits | probability
         ...     print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
-        |   262 | the      | -1.413 | 24.33%
+        |   262 | the      | -1.414 | 24.33%
         |  1110 | day      | -2.609 | 7.36%
-        |   618 | when     | -2.009 | 13.41%
+        |   618 | when     | -2.010 | 13.40%
         |   356 | we       | -1.859 | 15.58%
         |   460 | can      | -2.508 | 8.14%
@@ -549,7 +549,7 @@ class TFGenerationMixin:
         >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
         >>> # Tip: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the
         >>> # use case, you might want to recompute it with `normalize_logits=True`.
-        >>> output_length = input_length + np.sum(transition_scores.numpy() < 0, axis=1)
+        >>> output_length = np.sum(transition_scores.numpy() < 0, axis=1)
         >>> length_penalty = model.generation_config.length_penalty
         >>> reconstructed_scores = np.sum(transition_scores, axis=1) / (output_length**length_penalty)
         >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
```
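The dropped `input_length` term is the substantive fix here: every valid transition contributes one strictly negative log-probability, so `transition_scores < 0` already counts only the generated tokens, and the length penalty is applied to that count alone. A hedged PyTorch sketch of the same reconstruction (small checkpoint for illustration; assumes a transformers version where beam scores are normalized by generated length, which is what this doctest fix tracks):

```python
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer(["Today is"], return_tensors="pt")

# Beam search, keeping the scores needed to recompute sequences_scores.
outputs = model.generate(
    **inputs,
    max_new_tokens=5,
    num_beams=4,
    num_return_sequences=4,
    return_dict_in_generate=True,
    output_scores=True,
    pad_token_id=tokenizer.eos_token_id,
)
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
)

# Valid transitions have log-probability < 0, so this counts generated tokens
# per beam; note there is no `input_length` term anymore.
output_length = np.sum(transition_scores.numpy() < 0, axis=1)
length_penalty = model.generation_config.length_penalty
reconstructed_scores = transition_scores.sum(axis=1).numpy() / (output_length**length_penalty)
print(np.allclose(outputs.sequences_scores, reconstructed_scores))
```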
```diff
@@ -705,9 +705,9 @@ class GenerationMixin:
             input_ids=input_ids,
             assistant_model=assistant_model,
             generation_config=generation_config,
-            logits_processor=logits_processor,
             model_kwargs=model_kwargs,
             inputs_tensor=inputs_tensor,
+            logits_processor=logits_processor,
         )
         return candidate_generator
```
```diff
@@ -4601,24 +4601,18 @@ class GenerationMixin:
         >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
         >>> input_prompt = "It might be possible to"
         >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
-        >>> # instantiate logits processors
-        >>> logits_processor = LogitsProcessorList(
-        ...     [
-        ...         MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
-        ...     ]
-        ... )
+        >>> model.generation_config.min_length = 10
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
         >>> candidate_generator = AssistedCandidateGenerator(
         ...     input_ids=input_ids,
         ...     assistant_model=assistant_model,
         ...     generation_config=model.generation_config,
-        ...     logits_processor=logits_processor,
         ...     model_kwargs={},
         ... )
         >>> outputs = model._assisted_decoding(
         ...     input_ids,
         ...     candidate_generator=candidate_generator,
-        ...     logits_processor=logits_processor,
         ...     stopping_criteria=stopping_criteria,
         ... )
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```
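The doctest now routes `min_length` through the generation config, mirroring the new `ValueError` guard above. For everyday use, the same assisted-generation run goes through the public API, which builds the candidate generator internally; a hedged sketch with the usual small checkpoints:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")

input_ids = tokenizer("It might be possible to", return_tensors="pt").input_ids

# `generate` wires up the candidate generator itself; `min_length` travels via
# the generation config instead of a hand-built MinLengthLogitsProcessor.
outputs = model.generate(
    input_ids,
    assistant_model=assistant_model,
    min_length=10,
    max_length=20,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```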