"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f9ec5ca90b4dc08ddcc04c759542afead6b0d9e4"
Unverified Commit fafe9093 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[generate] deprecate forced ids processor (#29487)

* [generate] deprecate forced ids processor

* add todo

* make message clearer
parent 11bbb505
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import inspect import inspect
import math import math
import warnings
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -1745,8 +1746,14 @@ class ForceTokensLogitsProcessor(LogitsProcessor): ...@@ -1745,8 +1746,14 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
``` ```
""" """
def __init__(self, force_token_map: List[List[int]]): def __init__(self, force_token_map: List[List[int]], _has_warned: Optional[bool] = False):
self.force_token_map = dict(force_token_map) self.force_token_map = dict(force_token_map)
if not _has_warned:
# TODO(Sanchit): remove this processor entirely in v4.40
warnings.warn(
"This `ForceTokensLogitsProcessor` has been deprecated and will be removed in v4.40. Should you need to provide prompt ids for generation, specify `input_ids` to the generate method for decoder-only models, or `decoder_input_ids` for encoder-decoder models.",
FutureWarning,
)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
......
...@@ -865,7 +865,12 @@ class GenerationMixin: ...@@ -865,7 +865,12 @@ class GenerationMixin:
SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
) )
if generation_config.forced_decoder_ids is not None: if generation_config.forced_decoder_ids is not None:
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)) # TODO(Sanchit): deprecate in v4.40 by removing this logic
warnings.warn(
"You have explicitly specified `forced_decoder_ids`. This functionality has been deprecated and will throw an error in v4.40. Please remove the `forced_decoder_ids` argument in favour of `input_ids` or `decoder_input_ids` respectively.",
FutureWarning,
)
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids, _has_warned=True))
processors = self._merge_criteria_processor_list(processors, logits_processor) processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present # `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True: if generation_config.renormalize_logits is True:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment