Unverified Commit 430a04a7 authored by Joao Gante, committed by GitHub

Docs: Update logit processors __call__ docs (#24729)

* tmp commit

* __call__ docs

* kwargs documented; shorter input_ids doc

* nit

* Update src/transformers/generation/logits_process.py
parent 6e2f0696
@@ -30,17 +30,10 @@ logger = get_logger(__name__)
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
search or log softmax for each vocabulary token when using beam search
kwargs (`Dict[str, Any]`, *optional*):
Additional logits processor specific kwargs.
Return:
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
@@ -53,7 +46,6 @@ class LogitsProcessor:
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""Torch method for processing logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
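For context, the shared docstring above documents the `(input_ids, scores) -> scores` contract that concrete subclasses implement by overriding the abstract `__call__`. A minimal, hedged sketch of a processor honoring that contract follows; the class name and the banned-token behavior are invented for illustration and are not part of this commit.

```python
import torch
from transformers import LogitsProcessor


class BanTokenLogitsProcessor(LogitsProcessor):
    """Toy example: forbid a single token id at every generation step."""

    def __init__(self, banned_token_id: int):
        self.banned_token_id = banned_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # scores has shape (batch_size, vocab_size); -inf removes the token from sampling/search
        scores[:, self.banned_token_id] = -float("inf")
        return scores
```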
@@ -64,7 +56,6 @@ class LogitsWarper:
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""Torch method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@@ -77,8 +68,22 @@ class LogitsProcessorList(list):
[`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token when using beam search
kwargs (`Dict[str, Any]`, *optional*):
Additional kwargs that are specific to a logits processor.
Return:
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
The processed prediction scores.
"""
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
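The loop above inspects each processor's `__call__` signature so that extra `kwargs` are only forwarded to processors that accept more than `(input_ids, scores)`. A hedged usage sketch of the list itself, with arbitrary placeholder shapes and parameter values:

```python
import torch
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    TemperatureLogitsWarper,
)

processors = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(min_length=5, eos_token_id=2),
        TemperatureLogitsWarper(temperature=0.7),
    ]
)

input_ids = torch.tensor([[0, 4, 9]])  # (batch_size=1, sequence_length=3)
scores = torch.randn(1, 32)            # (batch_size=1, vocab_size=32)

# Applies each processor in order and returns the processed scores
processed_scores = processors(input_ids, scores)
```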
@@ -116,6 +121,7 @@ class MinLengthLogitsProcessor(LogitsProcessor):
self.min_length = min_length
self.eos_token_id = eos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len < self.min_length:
@@ -154,6 +160,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
self.min_new_tokens = min_new_tokens
self.eos_token_id = eos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
if new_tokens_length < self.min_new_tokens:
@@ -178,7 +185,8 @@ class TemperatureLogitsWarper(LogitsWarper):
self.temperature = temperature
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores = scores / self.temperature
return scores
@@ -199,6 +207,7 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
self.penalty = penalty
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
score = torch.gather(scores, 1, input_ids)
@@ -227,6 +236,7 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
self.penalty = 1 / penalty
self.encoder_input_ids = encoder_input_ids
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
score = torch.gather(scores, 1, self.encoder_input_ids)
@@ -262,6 +272,7 @@ class TopPLogitsWarper(LogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
@@ -297,6 +308,7 @@ class TopKLogitsWarper(LogitsWarper):
self.top_k = max(top_k, min_tokens_to_keep)
self.filter_value = filter_value
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
top_k = min(self.top_k, scores.size(-1))  # Safety check
# Remove all tokens with a probability less than the last token of the top-k
@@ -330,6 +342,7 @@ class TypicalLogitsWarper(LogitsWarper):
self.mass = mass
self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
@@ -383,6 +396,7 @@ class EpsilonLogitsWarper(LogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Determine which indices to remove
probabilities = scores.softmax(dim=-1)
@@ -422,6 +436,7 @@ class EtaLogitsWarper(LogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Calculate the adaptive cutoff
probabilities = scores.softmax(dim=-1)
@@ -487,6 +502,7 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
self.ngram_size = ngram_size
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
num_batch_hypotheses = scores.shape[0]
cur_len = input_ids.shape[-1]
@@ -521,6 +537,7 @@ class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
self.batch_size = encoder_input_ids.shape[0]
self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# B x num_beams
num_hypos = scores.shape[0]
@@ -612,6 +629,7 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
self.length_greather_than_1_bias = None
self.prepared_bias_variables = False
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
if not self.prepared_bias_variables:
@@ -774,6 +792,7 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
self._num_beams = num_beams
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
mask = torch.full_like(scores, -math.inf)
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
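`PrefixConstrainedLogitsProcessor` masks everything to `-inf` and then re-enables only the token ids returned by the user's callback for each beam. As a hedged sketch, it is normally reached through `generate()` via `prefix_allowed_tokens_fn`; the model name and the allowed-token rule below are placeholders, not part of this commit:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")


def allowed_tokens(batch_id: int, input_ids: torch.Tensor) -> list:
    # Placeholder rule: only allow the first 1000 vocabulary ids
    return list(range(1000))


inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    num_beams=2,
    prefix_allowed_tokens_fn=allowed_tokens,  # wires in PrefixConstrainedLogitsProcessor
    max_new_tokens=5,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```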
@@ -821,6 +840,23 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
current_tokens: torch.LongTensor,
beam_group_idx: int,
) -> torch.FloatTensor:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token when using beam search
current_tokens (`torch.LongTensor` of shape `(batch_size)`):
Indices of input sequence tokens in the vocabulary, corresponding to the tokens selected by the other
beam groups in the current generation step.
beam_group_idx (`int`):
The index of the beam group currently being processed.
Return:
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
The processed prediction scores.
"""
# hamming diversity: penalise using same token in current group which was used in previous groups at
# the same time step
batch_size = current_tokens.shape[0] // self._num_beams
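The docstring added above covers the one processor whose `__call__` takes extra arguments (`current_tokens`, `beam_group_idx`); those are supplied internally during group beam search. A hedged end-to-end sketch that exercises it, where the model choice and generation settings are placeholders rather than anything from this commit:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
outputs = model.generate(
    **inputs,
    num_beams=4,
    num_beam_groups=2,
    diversity_penalty=1.0,  # enables HammingDiversityLogitsProcessor
    max_new_tokens=20,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```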
@@ -855,6 +891,7 @@ class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
def __init__(self, bos_token_id: int):
self.bos_token_id = bos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len == 1:
@@ -882,6 +919,7 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len == self.max_length - 1:
@@ -898,6 +936,7 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
the logits processor should only be used if necessary since it can slow down the generation method.
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# set all nan values to 0.0
scores[scores != scores] = 0.0
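A tiny hedged sketch of what this processor does to a broken score row: `nan` values become `0.0` and, in the full processor body, infinite values are clamped to the largest finite value for the dtype. The tensor values below are fabricated for illustration.

```python
import torch
from transformers import InfNanRemoveLogitsProcessor

processor = InfNanRemoveLogitsProcessor()

input_ids = torch.tensor([[0]])  # unused by this processor, required by the interface
scores = torch.tensor([[0.5, float("nan"), float("inf"), -1.0]])

clean = processor(input_ids, scores)
print(clean)  # nan -> 0.0; +inf -> the largest finite value for the dtype
```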
@@ -935,7 +974,8 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len > self.regulation_start:
for i in self.eos_token_id:
@@ -951,7 +991,8 @@ class LogitNormalization(LogitsProcessor, LogitsWarper):
the scores are normalized when comparing the hypotheses.
"""
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores = scores.log_softmax(dim=-1)
return scores
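`LogitNormalization` simply applies `log_softmax`, so exponentiating its output yields a proper probability distribution. A hedged sketch with fabricated logits:

```python
import torch
from transformers import LogitNormalization

processor = LogitNormalization()

input_ids = torch.tensor([[0, 7]])               # ignored by this processor
scores = torch.tensor([[2.0, 1.0, 0.5, -3.0]])   # raw logits

log_probs = processor(input_ids, scores)
print(log_probs.exp().sum(dim=-1))  # tensor([1.])
```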
@@ -967,7 +1008,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_index = begin_index
def __call__(self, input_ids, scores):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids.shape[1] == self.begin_index:
scores[:, self.begin_suppress_tokens] = -float("inf")
@@ -981,7 +1023,8 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
def __init__(self, suppress_tokens):
self.suppress_tokens = list(suppress_tokens)
def __call__(self, input_ids, scores):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores[:, self.suppress_tokens] = -float("inf")
return scores
@@ -994,7 +1037,8 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
def __init__(self, force_token_map: List[List[int]]):
self.force_token_map = dict(force_token_map)
def __call__(self, input_ids, scores):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
generation_idx = input_ids.shape[-1]
current_token = self.force_token_map.get(generation_idx, None)
if current_token is not None:
@@ -1030,7 +1074,8 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
self.begin_index -= 1
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
def __call__(self, input_ids, scores):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# suppress <|notimestamps|> which is handled by without_timestamps
scores[:, self.no_timestamps_token_id] = -float("inf")
@@ -1089,7 +1134,8 @@ class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
f"{guidance_scale}."
)
def __call__(self, input_ids, scores):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# simple check to make sure we have compatible batch sizes between our
# logits scores (cond + uncond) and input ids (cond only)
if scores.shape[0] != 2 * input_ids.shape[0]: