"conftest.py" did not exist on "15d18e030722d5e51160fddc6f920939134e2e1e"
Unverified Commit fafe9093 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[generate] deprecate forced ids processor (#29487)

* [generate] deprecate forced ids processor

* add todo

* make message clearer
parent 11bbb505
......@@ -15,6 +15,7 @@
import inspect
import math
import warnings
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
......@@ -1745,8 +1746,14 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
```
"""
def __init__(self, force_token_map: List[List[int]]):
def __init__(self, force_token_map: List[List[int]], _has_warned: Optional[bool] = False):
self.force_token_map = dict(force_token_map)
if not _has_warned:
# TODO(Sanchit): remove this processor entirely in v4.40
warnings.warn(
"This `ForceTokensLogitsProcessor` has been deprecated and will be removed in v4.40. Should you need to provide prompt ids for generation, specify `input_ids` to the generate method for decoder-only models, or `decoder_input_ids` for encoder-decoder models.",
FutureWarning,
)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
......
......@@ -865,7 +865,12 @@ class GenerationMixin:
SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
)
if generation_config.forced_decoder_ids is not None:
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
# TODO(Sanchit): deprecate in v4.40 by removing this logic
warnings.warn(
"You have explicitly specified `forced_decoder_ids`. This functionality has been deprecated and will throw an error in v4.40. Please remove the `forced_decoder_ids` argument in favour of `input_ids` or `decoder_input_ids` respectively.",
FutureWarning,
)
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids, _has_warned=True))
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment