Unverified commit 7130a22d authored by Joao Gante, committed by GitHub

Generate: consistently handle special tokens as tensors (#30624)



* tmp commit

* [test_all] mvp

* missing not

* [test_all] final test fixes

* fix musicgen_melody and rag

* [test_all] empty commit

* PR comments

* Update src/transformers/generation/utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 5413b898
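
The change applies one normalization pattern throughout `generate()`'s helpers: special tokens (`eos_token_id`, `pad_token_id`, suppressed token lists, and so on) may now arrive as an `int`, a `List[int]`, or a `torch.Tensor`, and are converted to a tensor up front so downstream code only deals with tensors. A minimal sketch of that pattern; the helper name `_as_token_tensor` is chosen here for illustration only (the PR performs the conversion inline in each class):

```python
from typing import List, Optional, Union

import torch


def _as_token_tensor(
    token_id: Optional[Union[int, List[int], torch.Tensor]],
) -> Optional[torch.Tensor]:
    """Normalize a special token spec (int, list of ints, or tensor) into a tensor."""
    if token_id is None or isinstance(token_id, torch.Tensor):
        return token_id  # nothing to do: missing token or already a tensor
    if isinstance(token_id, int):
        token_id = [token_id]
    return torch.tensor(token_id)


# _as_token_tensor(2), _as_token_tensor([2, 7]) and _as_token_tensor(torch.tensor([2]))
# all yield a tensor (or pass it through), so callers can rely on `.to(device)` and `torch.isin`.
```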
......@@ -218,8 +218,8 @@ class BeamSearchScorer(BeamScorer):
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
......@@ -245,8 +245,10 @@ class BeamSearchScorer(BeamScorer):
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
for batch_idx in range(batch_size):
batch_group_idx = batch_idx * self.num_beam_groups + group_index
......@@ -322,15 +324,17 @@ class BeamSearchScorer(BeamScorer):
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps) // self.num_beam_groups
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
# finalize all open beam hypotheses and add to generated hypotheses
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
......@@ -513,8 +517,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
scores_for_all_vocab: torch.FloatTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.Tensor]:
......@@ -578,8 +582,10 @@ class ConstrainedBeamSearchScorer(BeamScorer):
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
......@@ -811,15 +817,17 @@ class ConstrainedBeamSearchScorer(BeamScorer):
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
......
......@@ -108,8 +108,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
Examples:
......@@ -137,14 +137,14 @@ class MinLengthLogitsProcessor(LogitsProcessor):
```
"""
def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
eos_token_id = torch.tensor(eos_token_id)
self.min_length = min_length
self.eos_token_id = eos_token_id
......@@ -152,8 +152,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
scores_processed = scores.clone()
if input_ids.shape[-1] < self.min_length:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
......@@ -171,8 +171,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
input length.
min_new_tokens (`int`):
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
Examples:
......@@ -195,7 +195,9 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
```
"""
def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
def __init__(
self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor]
):
for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip),
("min_new_tokens", min_new_tokens),
......@@ -203,10 +205,10 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
if not isinstance(arg_value, int) or arg_value < 0:
raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
eos_token_id = torch.tensor(eos_token_id)
self.prompt_length_to_skip = prompt_length_to_skip
self.min_new_tokens = min_new_tokens
......@@ -217,8 +219,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
scores_processed = scores.clone()
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
if new_tokens_length < self.min_new_tokens:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
......@@ -1195,8 +1197,8 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
Args:
bad_words_ids (`List[List[int]]`):
List of list of token ids that are not allowed to be generated.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`, *optional*):
The id(s) of the *end-of-sequence* token.
Examples:
......@@ -1233,15 +1235,19 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
```
"""
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
def __init__(
self, bad_words_ids: List[List[int]], eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None
):
self.bad_word_ids = bad_words_ids
self._validate_arguments()
# Filter EOS token from bad_words_ids
if eos_token_id is None:
eos_token_id = []
if eos_token_id is not None:
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
bad_words_ids = list(
filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
)
......@@ -1522,9 +1528,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
eos_token_id (`Union[int, List[int]]`):
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
Examples:
......@@ -1548,15 +1553,22 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
```
"""
def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
self.max_length = max_length
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
self.eos_token_id = self.eos_token_id.to(scores.device)
scores_processed = scores
if cur_len == self.max_length - 1:
scores_processed = torch.full_like(scores, -math.inf)
......@@ -1595,8 +1607,8 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
exponential_decay_length_penalty (`tuple(int, float)`):
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
starts and `decay_factor` represents the factor of exponential decay
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
input_ids_seq_length (`int`):
The length of the input sequence.
......@@ -1656,26 +1668,32 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
def __init__(
self,
exponential_decay_length_penalty: Tuple[int, float],
eos_token_id: Union[int, List[int]],
eos_token_id: Union[int, List[int], torch.Tensor],
input_ids_seq_length: int,
):
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
self.regulation_factor = exponential_decay_length_penalty[1]
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
self.eos_token_id = self.eos_token_id.to(scores.device)
penalties = torch.zeros_like(scores)
scores_processed = scores
if cur_len > self.regulation_start:
for i in self.eos_token_id:
penalty_idx = cur_len - self.regulation_start
# To support negative logits we compute the penalty of the absolute value and add to the original logit
penalty = torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1)
penalties[:, i] = penalty
penalty = torch.abs(scores[:, self.eos_token_id]) * (pow(self.regulation_factor, penalty_idx) - 1)
penalties[:, self.eos_token_id] = penalty
scores_processed = scores + penalties
return scores_processed
......@@ -1753,7 +1771,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
"""
def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens))
self.begin_index = begin_index
def set_begin_index(self, begin_index):
......@@ -1762,8 +1780,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
begin_suppress_tokens = torch.tensor(self.begin_suppress_tokens, device=scores.device)
suppress_token_mask = torch.isin(vocab_tensor, begin_suppress_tokens)
self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
scores_processed = scores
if input_ids.shape[-1] == self.begin_index:
scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
......@@ -1801,13 +1819,13 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
"""
def __init__(self, suppress_tokens):
self.suppress_tokens = list(suppress_tokens)
self.suppress_tokens = torch.tensor(list(suppress_tokens))
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
suppress_tokens = torch.tensor(self.suppress_tokens, device=scores.device)
suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens)
self.suppress_tokens = self.suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
scores = torch.where(suppress_token_mask, -float("inf"), scores)
return scores
......@@ -2268,16 +2286,22 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
</Tip>
Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
min_eos_p (`float`, *optional*):
Minimum end of speech threshold.
"""
def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
if min_eos_p is not None and min_eos_p <= 0:
raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
self.min_eos_p = min_eos_p
......@@ -2285,6 +2309,7 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores_processed = scores
self.eos_token_id = self.eos_token_id.to(scores.device)
if self.min_eos_p:
probs = torch.nn.functional.softmax(scores.float(), dim=-1)
# create scores full of -inf except for the eos_token_id
......
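
For the logits-processor changes above, EOS handling now builds a boolean mask over the vocabulary with `torch.isin` and moves the stored tensor to the scores device lazily. A toy, self-contained sketch of that masking step (values invented for illustration; the logic mirrors `MinLengthLogitsProcessor.__call__` in the diff):

```python
import math

import torch

scores = torch.randn(2, 10)           # (batch_size, vocab_size), e.g. model logits
eos_token_id = torch.tensor([2, 7])   # tensor form, as normalized at __init__ after this PR

vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_id = eos_token_id.to(scores.device)            # lazy device move, as in the processors
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)  # True at every EOS position in the vocab

# MinLengthLogitsProcessor-style: forbid EOS while the sequence is still too short.
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
```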
......@@ -470,29 +470,32 @@ class EosTokenCriteria(StoppingCriteria):
By default, it uses the `model.generation_config.eos_token_id`.
Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
"""
def __init__(self, eos_token_id: Union[int, List[int]]):
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor]):
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = torch.tensor(eos_token_id)
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
self.eos_token_id = self.eos_token_id.to(input_ids.device)
if input_ids.device.type == "mps":
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
is_done = (
input_ids[:, -1]
.tile(self.eos_token_id.shape[0], 1)
.eq(self.eos_token_id.unsqueeze(1).to(input_ids.device))
.eq(self.eos_token_id.unsqueeze(1))
.sum(dim=0)
.bool()
.squeeze()
)
else:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
return is_done
......
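
To make the `EosTokenCriteria` change concrete, here is a small sketch of the non-MPS path with toy values (the separate MPS branch in the diff only works around the `torch.isin` limitation referenced in the linked PyTorch issue):

```python
import torch

input_ids = torch.tensor([[5, 9, 2], [5, 9, 4]])  # (batch_size, seq_len), made-up token ids
eos_token_id = torch.tensor([2, 7])               # stored as a tensor by EosTokenCriteria

eos_token_id = eos_token_id.to(input_ids.device)      # lazy device move, as in __call__
is_done = torch.isin(input_ids[:, -1], eos_token_id)  # tensor([True, False])
```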
This diff is collapsed.
......@@ -18,7 +18,7 @@ import inspect
import math
import random
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
......@@ -2587,6 +2587,24 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def _get_decoder_start_token_id(
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
) -> int:
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
@torch.no_grad()
def generate(
self,
......
......@@ -18,7 +18,7 @@ import inspect
import math
import random
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
......@@ -2452,6 +2452,25 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
param.requires_grad = False
self.text_encoder._requires_grad = False
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration._get_decoder_start_token_id
def _get_decoder_start_token_id(
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
) -> int:
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
@torch.no_grad()
def generate(
self,
......
......@@ -1458,6 +1458,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)
# set default parameters
n_docs = n_docs if n_docs is not None else self.config.n_docs
......
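
The RAG change above and the test changes below follow the same recipe: copy the model's `generation_config`, let `_prepare_special_tokens` populate it (per the PR title, special tokens on the prepared config become tensors), and read values such as `decoder_start_token_id` from the copy instead of relying on `model._get_decoder_start_token_id()`. A usage sketch with an illustrative checkpoint, assuming a generate-capable seq2seq model:

```python
import copy

import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # illustrative checkpoint

generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config)  # private helper; fills in the special tokens

# Same pattern as the updated tests: build decoder inputs from the prepared config.
decoder_input_ids = (
    torch.zeros((1, 1), dtype=torch.long) + generation_config.decoder_start_token_id
)
```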
......@@ -14,6 +14,7 @@
# limitations under the License.
import copy
import inspect
import tempfile
import unittest
......@@ -168,7 +169,9 @@ class GenerationTesterMixin:
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id()
generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config)
input_ids = torch.zeros_like(input_ids[:, :1]) + generation_config.decoder_start_token_id
attention_mask = None
return encoder_outputs, input_ids, attention_mask
......
......@@ -414,9 +414,11 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config)
input_ids = (
torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device)
+ model._get_decoder_start_token_id()
+ generation_config.decoder_start_token_id
)
attention_mask = None
return encoder_outputs, input_ids, attention_mask
......
......@@ -430,9 +430,11 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config)
input_ids = (
torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device)
+ model._get_decoder_start_token_id()
+ generation_config.decoder_start_token_id
)
attention_mask = None
return encoder_outputs, input_ids, attention_mask
......
......@@ -645,7 +645,9 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
num_interleave, dim=0
)
input_ids = input_ids[:, :, 0]
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + model._get_decoder_start_token_id()
generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config)
input_ids = torch.zeros_like(input_ids[:, :1]) + generation_config.decoder_start_token_id
attention_mask = None
return encoder_outputs, input_ids, attention_mask
......
......@@ -833,10 +833,10 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config)
input_ids = input_ids[:, :, 0]
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + torch.tensor(
[model._get_decoder_start_token_id()], device=input_ids.device
)
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + generation_config.decoder_start_token_id
attention_mask = None
return encoder_outputs, input_ids, attention_mask
......