"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "010e0460b22ddd7f74e31163f69ab3da2e9741ba"
Unverified Commit 83238eee authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Pass device in Logits Processor's init (#29804)



* add device in logits processor

* remove device when not needed

* codestyle

* tests

* forgot `melody` version

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* codestyle

* updates

---------
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>
parent c73ee133
...@@ -110,6 +110,8 @@ class MinLengthLogitsProcessor(LogitsProcessor): ...@@ -110,6 +110,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`. The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int], torch.Tensor]`): eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token. The id(s) of the *end-of-sequence* token.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
Examples: Examples:
...@@ -137,14 +139,14 @@ class MinLengthLogitsProcessor(LogitsProcessor): ...@@ -137,14 +139,14 @@ class MinLengthLogitsProcessor(LogitsProcessor):
``` ```
""" """
def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]): def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
if not isinstance(min_length, int) or min_length < 0: if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}") raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
if not isinstance(eos_token_id, torch.Tensor): if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int): if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id] eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id) eos_token_id = torch.tensor(eos_token_id, device=device)
self.min_length = min_length self.min_length = min_length
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
...@@ -152,7 +154,6 @@ class MinLengthLogitsProcessor(LogitsProcessor): ...@@ -152,7 +154,6 @@ class MinLengthLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
scores_processed = scores.clone() scores_processed = scores.clone()
if input_ids.shape[-1] < self.min_length: if input_ids.shape[-1] < self.min_length:
...@@ -173,6 +174,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): ...@@ -173,6 +174,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`. The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int], torch.Tensor]`): eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token. The id(s) of the *end-of-sequence* token.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
Examples: Examples:
...@@ -196,7 +199,11 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): ...@@ -196,7 +199,11 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
""" """
def __init__( def __init__(
self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor] self,
prompt_length_to_skip: int,
min_new_tokens: int,
eos_token_id: Union[int, List[int], torch.Tensor],
device: str = "cpu",
): ):
for arg_name, arg_value in [ for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip), ("prompt_length_to_skip", prompt_length_to_skip),
...@@ -208,7 +215,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): ...@@ -208,7 +215,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
if not isinstance(eos_token_id, torch.Tensor): if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int): if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id] eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id) eos_token_id = torch.tensor(eos_token_id, device=device)
self.prompt_length_to_skip = prompt_length_to_skip self.prompt_length_to_skip = prompt_length_to_skip
self.min_new_tokens = min_new_tokens self.min_new_tokens = min_new_tokens
...@@ -219,7 +226,6 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): ...@@ -219,7 +226,6 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
scores_processed = scores.clone() scores_processed = scores.clone()
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
if new_tokens_length < self.min_new_tokens: if new_tokens_length < self.min_new_tokens:
scores_processed = torch.where(eos_token_mask, -math.inf, scores) scores_processed = torch.where(eos_token_mask, -math.inf, scores)
...@@ -779,6 +785,8 @@ class EtaLogitsWarper(LogitsWarper): ...@@ -779,6 +785,8 @@ class EtaLogitsWarper(LogitsWarper):
Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities. Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation, For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
even if all tokens have probabilities below the cutoff `eta`. even if all tokens have probabilities below the cutoff `eta`.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
Examples: Examples:
```python ```python
...@@ -806,7 +814,9 @@ class EtaLogitsWarper(LogitsWarper): ...@@ -806,7 +814,9 @@ class EtaLogitsWarper(LogitsWarper):
``` ```
""" """
def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): def __init__(
self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1, device: str = "cpu"
):
epsilon = float(epsilon) epsilon = float(epsilon)
if epsilon <= 0 or epsilon >= 1: if epsilon <= 0 or epsilon >= 1:
raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}") raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
...@@ -817,13 +827,12 @@ class EtaLogitsWarper(LogitsWarper): ...@@ -817,13 +827,12 @@ class EtaLogitsWarper(LogitsWarper):
f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}" f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
) )
self.epsilon = torch.tensor(epsilon) self.epsilon = torch.tensor(epsilon, device=device)
self.filter_value = filter_value self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Calculate the adaptive cutoff
probabilities = scores.softmax(dim=-1) probabilities = scores.softmax(dim=-1)
entropy = torch.distributions.Categorical(logits=scores).entropy() entropy = torch.distributions.Categorical(logits=scores).entropy()
eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None] eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
...@@ -1530,6 +1539,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): ...@@ -1530,6 +1539,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
The maximum length of the sequence to be generated. The maximum length of the sequence to be generated.
eos_token_id (`Union[int, List[int], torch.Tensor]`): eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token. The id(s) of the *end-of-sequence* token.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
Examples: Examples:
...@@ -1553,13 +1564,13 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): ...@@ -1553,13 +1564,13 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
``` ```
""" """
def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]): def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
self.max_length = max_length self.max_length = max_length
if not isinstance(eos_token_id, torch.Tensor): if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int): if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id] eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id) eos_token_id = torch.tensor(eos_token_id, device=device)
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
...@@ -1568,7 +1579,6 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): ...@@ -1568,7 +1579,6 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1] cur_len = input_ids.shape[-1]
self.eos_token_id = self.eos_token_id.to(scores.device)
scores_processed = scores scores_processed = scores
if cur_len == self.max_length - 1: if cur_len == self.max_length - 1:
scores_processed = torch.full_like(scores, -math.inf) scores_processed = torch.full_like(scores, -math.inf)
...@@ -1770,8 +1780,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): ...@@ -1770,8 +1780,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
``` ```
""" """
def __init__(self, begin_suppress_tokens, begin_index): def __init__(self, begin_suppress_tokens, begin_index, device: str = "cpu"):
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens)) self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens), device=device)
self.begin_index = begin_index self.begin_index = begin_index
def set_begin_index(self, begin_index): def set_begin_index(self, begin_index):
...@@ -1780,7 +1790,6 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): ...@@ -1780,7 +1790,6 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens) suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
scores_processed = scores scores_processed = scores
if input_ids.shape[-1] == self.begin_index: if input_ids.shape[-1] == self.begin_index:
...@@ -1818,13 +1827,12 @@ class SuppressTokensLogitsProcessor(LogitsProcessor): ...@@ -1818,13 +1827,12 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
``` ```
""" """
def __init__(self, suppress_tokens): def __init__(self, suppress_tokens, device: str = "cpu"):
self.suppress_tokens = torch.tensor(list(suppress_tokens)) self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.suppress_tokens = self.suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens) suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
scores = torch.where(suppress_token_mask, -float("inf"), scores) scores = torch.where(suppress_token_mask, -float("inf"), scores)
return scores return scores
...@@ -1915,7 +1923,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): ...@@ -1915,7 +1923,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
""" """
def __init__( def __init__(
self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None self,
generate_config,
begin_index: Optional[int] = None,
_detect_timestamp_from_logprob: Optional[bool] = None,
): # support for the kwargs ): # support for the kwargs
self.no_timestamps_token_id = generate_config.no_timestamps_token_id self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1 self.timestamp_begin = generate_config.no_timestamps_token_id + 1
...@@ -2292,11 +2303,11 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): ...@@ -2292,11 +2303,11 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
Minimum end of speech threshold. Minimum end of speech threshold.
""" """
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float): def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float, device: str = "cpu"):
if not isinstance(eos_token_id, torch.Tensor): if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int): if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id] eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id) eos_token_id = torch.tensor(eos_token_id, device=device)
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any(): if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
...@@ -2309,7 +2320,6 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): ...@@ -2309,7 +2320,6 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores_processed = scores scores_processed = scores
self.eos_token_id = self.eos_token_id.to(scores.device)
if self.min_eos_p: if self.min_eos_p:
probs = torch.nn.functional.softmax(scores.float(), dim=-1) probs = torch.nn.functional.softmax(scores.float(), dim=-1)
# create scores full of -inf except for the eos_token_id # create scores full of -inf except for the eos_token_id
......
...@@ -723,6 +723,7 @@ class GenerationMixin: ...@@ -723,6 +723,7 @@ class GenerationMixin:
def _get_logits_warper( def _get_logits_warper(
self, self,
generation_config: GenerationConfig, generation_config: GenerationConfig,
device: str,
) -> LogitsProcessorList: ) -> LogitsProcessorList:
""" """
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
...@@ -765,7 +766,9 @@ class GenerationMixin: ...@@ -765,7 +766,9 @@ class GenerationMixin:
) )
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
warpers.append( warpers.append(
EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep) EtaLogitsWarper(
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
)
) )
# `LogitNormalization` should always be the last logit processor, when present # `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True: if generation_config.renormalize_logits is True:
...@@ -818,7 +821,8 @@ class GenerationMixin: ...@@ -818,7 +821,8 @@ class GenerationMixin:
): ):
processors.append( processors.append(
EncoderRepetitionPenaltyLogitsProcessor( EncoderRepetitionPenaltyLogitsProcessor(
penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids penalty=generation_config.encoder_repetition_penalty,
encoder_input_ids=encoder_input_ids,
) )
) )
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
...@@ -830,18 +834,30 @@ class GenerationMixin: ...@@ -830,18 +834,30 @@ class GenerationMixin:
and generation_config.encoder_no_repeat_ngram_size > 0 and generation_config.encoder_no_repeat_ngram_size > 0
): ):
processors.append( processors.append(
EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids) EncoderNoRepeatNGramLogitsProcessor(
generation_config.encoder_no_repeat_ngram_size,
encoder_input_ids,
)
) )
if generation_config.bad_words_ids is not None: if generation_config.bad_words_ids is not None:
processors.append( processors.append(
NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) NoBadWordsLogitsProcessor(
generation_config.bad_words_ids,
generation_config.eos_token_id,
)
) )
if ( if (
generation_config.min_length is not None generation_config.min_length is not None
and generation_config.eos_token_id is not None and generation_config.eos_token_id is not None
and generation_config.min_length > 0 and generation_config.min_length > 0
): ):
processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)) processors.append(
MinLengthLogitsProcessor(
generation_config.min_length,
generation_config.eos_token_id,
device=device,
)
)
if ( if (
generation_config.min_new_tokens is not None generation_config.min_new_tokens is not None
and generation_config.eos_token_id is not None and generation_config.eos_token_id is not None
...@@ -849,20 +865,32 @@ class GenerationMixin: ...@@ -849,20 +865,32 @@ class GenerationMixin:
): ):
processors.append( processors.append(
MinNewTokensLengthLogitsProcessor( MinNewTokensLengthLogitsProcessor(
input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id input_ids_seq_length,
generation_config.min_new_tokens,
generation_config.eos_token_id,
device=device,
) )
) )
if prefix_allowed_tokens_fn is not None: if prefix_allowed_tokens_fn is not None:
processors.append( processors.append(
PrefixConstrainedLogitsProcessor( PrefixConstrainedLogitsProcessor(
prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups prefix_allowed_tokens_fn,
generation_config.num_beams // generation_config.num_beam_groups,
) )
) )
if generation_config.forced_bos_token_id is not None: if generation_config.forced_bos_token_id is not None:
processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)) processors.append(
ForcedBOSTokenLogitsProcessor(
generation_config.forced_bos_token_id,
)
)
if generation_config.forced_eos_token_id is not None: if generation_config.forced_eos_token_id is not None:
processors.append( processors.append(
ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) ForcedEOSTokenLogitsProcessor(
generation_config.max_length,
generation_config.forced_eos_token_id,
device=device,
)
) )
if generation_config.remove_invalid_values is True: if generation_config.remove_invalid_values is True:
processors.append(InfNanRemoveLogitsProcessor()) processors.append(InfNanRemoveLogitsProcessor())
...@@ -875,7 +903,12 @@ class GenerationMixin: ...@@ -875,7 +903,12 @@ class GenerationMixin:
) )
) )
if generation_config.suppress_tokens is not None: if generation_config.suppress_tokens is not None:
processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens)) processors.append(
SuppressTokensLogitsProcessor(
generation_config.suppress_tokens,
device=device,
)
)
if generation_config.begin_suppress_tokens is not None: if generation_config.begin_suppress_tokens is not None:
begin_index = input_ids_seq_length begin_index = input_ids_seq_length
begin_index = ( begin_index = (
...@@ -887,7 +920,11 @@ class GenerationMixin: ...@@ -887,7 +920,11 @@ class GenerationMixin:
# generation starts after the last token that is forced # generation starts after the last token that is forced
begin_index += generation_config.forced_decoder_ids[-1][0] begin_index += generation_config.forced_decoder_ids[-1][0]
processors.append( processors.append(
SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) SuppressTokensAtBeginLogitsProcessor(
generation_config.begin_suppress_tokens,
begin_index,
device=device,
)
) )
if generation_config.forced_decoder_ids is not None: if generation_config.forced_decoder_ids is not None:
# TODO(Sanchit): deprecate in v4.40 by removing this logic # TODO(Sanchit): deprecate in v4.40 by removing this logic
...@@ -1779,7 +1816,12 @@ class GenerationMixin: ...@@ -1779,7 +1816,12 @@ class GenerationMixin:
# 12. prepare logits warper (if `do_sample` is `True`) # 12. prepare logits warper (if `do_sample` is `True`)
prepared_logits_warper = ( prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None self._get_logits_warper(
generation_config,
device=input_ids.device,
)
if generation_config.do_sample
else None
) )
# 13. run assisted generate # 13. run assisted generate
...@@ -1812,7 +1854,9 @@ class GenerationMixin: ...@@ -1812,7 +1854,9 @@ class GenerationMixin:
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper # 11. prepare logits warper
prepared_logits_warper = ( prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
) )
# 12. expand input_ids with `num_return_sequences` additional sequences per batch # 12. expand input_ids with `num_return_sequences` additional sequences per batch
...@@ -1838,7 +1882,9 @@ class GenerationMixin: ...@@ -1838,7 +1882,9 @@ class GenerationMixin:
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
# 11. prepare logits warper # 11. prepare logits warper
prepared_logits_warper = ( prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
) )
# 12. prepare beam search scorer # 12. prepare beam search scorer
......
...@@ -1729,6 +1729,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): ...@@ -1729,6 +1729,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
encoder_input_ids=input_ids, encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None, prefix_allowed_tokens_fn=None,
logits_processor=logits_processor, logits_processor=logits_processor,
device=input_ids.device,
) )
# 10. prepare stopping criteria # 10. prepare stopping criteria
...@@ -1756,7 +1757,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): ...@@ -1756,7 +1757,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
elif is_sample_gen_mode: elif is_sample_gen_mode:
# 11. prepare logits warper # 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config) logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch # expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, model_kwargs = self._expand_inputs_for_generation(
...@@ -2822,6 +2823,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): ...@@ -2822,6 +2823,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
encoder_input_ids=inputs_tensor, encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None, prefix_allowed_tokens_fn=None,
logits_processor=logits_processor, logits_processor=logits_processor,
device=input_ids.device,
) )
# 10. prepare stopping criteria # 10. prepare stopping criteria
...@@ -2849,7 +2851,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): ...@@ -2849,7 +2851,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
elif is_sample_gen_mode: elif is_sample_gen_mode:
# 11. prepare logits warper # 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config) logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch # expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, model_kwargs = self._expand_inputs_for_generation(
......
...@@ -1666,6 +1666,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): ...@@ -1666,6 +1666,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
encoder_input_ids=input_ids, encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None, prefix_allowed_tokens_fn=None,
logits_processor=logits_processor, logits_processor=logits_processor,
device=input_ids.device,
) )
# 10. prepare stopping criteria # 10. prepare stopping criteria
...@@ -1693,7 +1694,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): ...@@ -1693,7 +1694,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
elif is_sample_gen_mode: elif is_sample_gen_mode:
# 11. prepare logits warper # 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config) logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch # expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, model_kwargs = self._expand_inputs_for_generation(
...@@ -2681,6 +2682,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): ...@@ -2681,6 +2682,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
encoder_input_ids=inputs_tensor, encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None, prefix_allowed_tokens_fn=None,
logits_processor=logits_processor, logits_processor=logits_processor,
device=input_ids.device,
) )
# 10. prepare stopping criteria # 10. prepare stopping criteria
...@@ -2708,7 +2710,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): ...@@ -2708,7 +2710,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
elif is_sample_gen_mode: elif is_sample_gen_mode:
# 11. prepare logits warper # 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config) logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch # expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, model_kwargs = self._expand_inputs_for_generation(
......
...@@ -1538,6 +1538,7 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1538,6 +1538,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
encoder_input_ids=context_input_ids, encoder_input_ids=context_input_ids,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor, logits_processor=logits_processor,
device=input_ids.device,
) )
prepared_stopping_criteria = self._get_stopping_criteria( prepared_stopping_criteria = self._get_stopping_criteria(
......
...@@ -548,13 +548,15 @@ class WhisperGenerationMixin: ...@@ -548,13 +548,15 @@ class WhisperGenerationMixin:
self._check_decoder_input_ids(kwargs=kwargs) self._check_decoder_input_ids(kwargs=kwargs)
# 3. Retrieve logits processors # 3. Retrieve logits processors
device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
begin_index = init_tokens.shape[1] begin_index = init_tokens.shape[1]
logits_processor = self._retrieve_logit_processors( logits_processor = self._retrieve_logit_processors(
generation_config=generation_config, generation_config=generation_config,
logits_processor=logits_processor, logits_processor=logits_processor,
begin_index=begin_index, # begin index is index of first generated decoder token begin_index=begin_index, # begin index is index of first generated decoder token
is_shortform=is_shortform, is_shortform=is_shortform,
num_beams=generation_config.num_beams, num_beams=kwargs.get("num_beams", 1),
device=device,
) )
# 5. If we're in shortform mode, simple generate the whole input at once and return the output # 5. If we're in shortform mode, simple generate the whole input at once and return the output
...@@ -1400,7 +1402,9 @@ class WhisperGenerationMixin: ...@@ -1400,7 +1402,9 @@ class WhisperGenerationMixin:
return max_frames, seek return max_frames, seek
def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams): def _retrieve_logit_processors(
self, generation_config, logits_processor, begin_index, is_shortform, num_beams, device
):
if generation_config.return_timestamps is True: if generation_config.return_timestamps is True:
timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index) timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
logits_processor = ( logits_processor = (
...@@ -1408,7 +1412,7 @@ class WhisperGenerationMixin: ...@@ -1408,7 +1412,7 @@ class WhisperGenerationMixin:
) )
if generation_config.suppress_tokens is not None: if generation_config.suppress_tokens is not None:
suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens) suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
logits_processor = ( logits_processor = (
[suppress_tokens_processor] [suppress_tokens_processor]
if logits_processor is None if logits_processor is None
...@@ -1418,7 +1422,7 @@ class WhisperGenerationMixin: ...@@ -1418,7 +1422,7 @@ class WhisperGenerationMixin:
if generation_config.begin_suppress_tokens is not None: if generation_config.begin_suppress_tokens is not None:
begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor( begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
generation_config.begin_suppress_tokens, begin_index=begin_index generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
) )
logits_processor = ( logits_processor = (
[begin_suppress_processor] [begin_suppress_processor]
......
...@@ -69,7 +69,7 @@ class LogitsProcessorTest(unittest.TestCase): ...@@ -69,7 +69,7 @@ class LogitsProcessorTest(unittest.TestCase):
batch_size = 4 batch_size = 4
eos_token_id = 0 eos_token_id = 0
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id, device=torch_device)
# check that min length is applied at length 5 # check that min length is applied at length 5
input_ids = ids_tensor((batch_size, 5), vocab_size=20) input_ids = ids_tensor((batch_size, 5), vocab_size=20)
...@@ -91,7 +91,7 @@ class LogitsProcessorTest(unittest.TestCase): ...@@ -91,7 +91,7 @@ class LogitsProcessorTest(unittest.TestCase):
# check that first input is skipped (min new length applying) # check that first input is skipped (min new length applying)
input_ids = ids_tensor((batch_size, 5), vocab_size=20) input_ids = ids_tensor((batch_size, 5), vocab_size=20)
new_min_dist_processor = MinNewTokensLengthLogitsProcessor( new_min_dist_processor = MinNewTokensLengthLogitsProcessor(
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id, device=torch_device
) )
expected_eos_scores_before_min_length = batch_size * [-float("inf")] expected_eos_scores_before_min_length = batch_size * [-float("inf")]
...@@ -450,7 +450,7 @@ class LogitsProcessorTest(unittest.TestCase): ...@@ -450,7 +450,7 @@ class LogitsProcessorTest(unittest.TestCase):
torch.tensor([[0.0, 0.1, 0.8, 0.1], [0.01, 0.04, 0.9, 0.05]], device=torch_device, dtype=torch.float) torch.tensor([[0.0, 0.1, 0.8, 0.1], [0.01, 0.04, 0.9, 0.05]], device=torch_device, dtype=torch.float)
) )
eta_warp = EtaLogitsWarper(0.0625) eta_warp = EtaLogitsWarper(0.0625, device=torch_device)
filtered_dist = torch.exp(eta_warp(input_ids, dist)) filtered_dist = torch.exp(eta_warp(input_ids, dist))
# dist should be filtered to only keep values with proba >= min(0.0625, sqrt(0.0625) * e^-H(p)) # dist should be filtered to only keep values with proba >= min(0.0625, sqrt(0.0625) * e^-H(p))
...@@ -474,7 +474,7 @@ class LogitsProcessorTest(unittest.TestCase): ...@@ -474,7 +474,7 @@ class LogitsProcessorTest(unittest.TestCase):
ramp_logits[1] = ramp_logits[1] * 100.0 ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept # make sure at least 2 tokens are kept
eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0) eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0, device=torch_device)
filtered_dist = eta_warp(input_ids, ramp_logits) filtered_dist = eta_warp(input_ids, ramp_logits)
# first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. # first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
...@@ -640,7 +640,7 @@ class LogitsProcessorTest(unittest.TestCase): ...@@ -640,7 +640,7 @@ class LogitsProcessorTest(unittest.TestCase):
scores_comp = scores.clone() scores_comp = scores.clone()
# instantiate all dist processors # instantiate all dist processors
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id, device=torch_device)
temp_dist_warp = TemperatureLogitsWarper(temperature=0.5) temp_dist_warp = TemperatureLogitsWarper(temperature=0.5)
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0) rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
top_k_warp = TopKLogitsWarper(3) top_k_warp = TopKLogitsWarper(3)
...@@ -767,7 +767,9 @@ class LogitsProcessorTest(unittest.TestCase): ...@@ -767,7 +767,9 @@ class LogitsProcessorTest(unittest.TestCase):
eos_token_id = 0 eos_token_id = 0
max_length = 5 max_length = 5
logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id) logits_processor = ForcedEOSTokenLogitsProcessor(
max_length=max_length, eos_token_id=eos_token_id, device=torch_device
)
# check that all scores are -inf except the eos_token_id when max_length-1 is reached # check that all scores are -inf except the eos_token_id when max_length-1 is reached
input_ids = ids_tensor((batch_size, 4), vocab_size=20) input_ids = ids_tensor((batch_size, 4), vocab_size=20)
...@@ -927,7 +929,7 @@ class LogitsProcessorTest(unittest.TestCase): ...@@ -927,7 +929,7 @@ class LogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(2, 4) scores = self._get_uniform_logits(2, 4)
scores[0][eos_token_id] = -6 ## less than log(min_eos_p) scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p, device=torch_device)
actual_scores = esp(input_ids, scores) actual_scores = esp(input_ids, scores)
expected_scores_list = [ expected_scores_list = [
scores[0].tolist(), scores[0].tolist(),
...@@ -943,7 +945,7 @@ class LogitsProcessorTest(unittest.TestCase): ...@@ -943,7 +945,7 @@ class LogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(2, 4) scores = self._get_uniform_logits(2, 4)
scores[0][eos_token_id] = -6 ## less than log(min_eos_p) scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p, device=torch_device)
actual_scores = esp(input_ids, scores) actual_scores = esp(input_ids, scores)
expected_scores_list = [ expected_scores_list = [
scores[0].tolist(), scores[0].tolist(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment