"...resnet50_tensorflow.git" did not exist on "15dddd8c6a5fcab414570a3b015bf12157ff3a48"
Unverified commit 83238eee authored by Raushan Turganbay, committed by GitHub

Pass device in Logits Processor's init (#29804)



* add device in logits processor

* remove device when not needed

* codestyle

* tests

* forgot `melody` version

* Update src/transformers/models/whisper/generation_whisper.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* codestyle

* updates

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent c73ee133
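In short: logits processors that own tensors can now be built directly on the target device, so `__call__` no longer has to move them every step. A minimal sketch of the new keyword (toy values; swap `"cpu"` for `"cuda"` on a GPU machine):

```python
import torch
from transformers import MinLengthLogitsProcessor

# Before this PR: eos_token_id was created on CPU and moved with
# `.to(scores.device)` inside every `__call__`.
# After: it is allocated on the requested device at construction time.
processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=2, device="cpu")
print(processor.eos_token_id.device)  # cpu
```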
src/transformers/generation/logits_process.py
@@ -110,6 +110,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
Examples:
@@ -137,14 +139,14 @@ class MinLengthLogitsProcessor(LogitsProcessor):
```
"""
def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
eos_token_id = torch.tensor(eos_token_id, device=device)
self.min_length = min_length
self.eos_token_id = eos_token_id
@@ -152,7 +154,6 @@ class MinLengthLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
scores_processed = scores.clone()
if input_ids.shape[-1] < self.min_length:
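For context, a small usage sketch of the processor above (hypothetical toy values): while fewer than `min_length` tokens have been generated, the EOS column is pushed to `-inf` so generation cannot stop early.

```python
import torch
from transformers import MinLengthLogitsProcessor

proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=0, device="cpu")
input_ids = torch.zeros((1, 5), dtype=torch.long)  # only 5 tokens so far
scores = torch.zeros((1, 20))                      # dummy uniform logits
out = proc(input_ids, scores)
assert out[0, 0] == -float("inf")  # EOS blocked until min_length is reached
```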
@@ -173,6 +174,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
Examples:
@@ -196,7 +199,11 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
"""
def __init__(
self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor]
self,
prompt_length_to_skip: int,
min_new_tokens: int,
eos_token_id: Union[int, List[int], torch.Tensor],
device: str = "cpu",
):
for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip),
@@ -208,7 +215,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
eos_token_id = torch.tensor(eos_token_id, device=device)
self.prompt_length_to_skip = prompt_length_to_skip
self.min_new_tokens = min_new_tokens
@@ -219,7 +226,6 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
scores_processed = scores.clone()
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
if new_tokens_length < self.min_new_tokens:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
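The `MinNewTokensLength` variant counts only tokens generated after the prompt. A toy sketch with assumed values:

```python
import torch
from transformers import MinNewTokensLengthLogitsProcessor

proc = MinNewTokensLengthLogitsProcessor(
    prompt_length_to_skip=5, min_new_tokens=3, eos_token_id=0, device="cpu"
)
input_ids = torch.zeros((1, 6), dtype=torch.long)  # 5 prompt tokens + 1 new
scores = torch.zeros((1, 20))
out = proc(input_ids, scores)
assert out[0, 0] == -float("inf")  # 1 new token < min_new_tokens, EOS blocked
```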
@@ -779,6 +785,8 @@ class EtaLogitsWarper(LogitsWarper):
Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
even if all tokens have probabilities below the cutoff `eta`.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
Examples:
```python
@@ -806,7 +814,9 @@ class EtaLogitsWarper(LogitsWarper):
```
"""
def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
def __init__(
self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1, device: str = "cpu"
):
epsilon = float(epsilon)
if epsilon <= 0 or epsilon >= 1:
raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
@@ -817,13 +827,12 @@ class EtaLogitsWarper(LogitsWarper):
f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
)
self.epsilon = torch.tensor(epsilon)
self.epsilon = torch.tensor(epsilon, device=device)
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Calculate the adaptive cutoff
probabilities = scores.softmax(dim=-1)
entropy = torch.distributions.Categorical(logits=scores).entropy()
eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
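The eta cutoff keeps a token only if its probability is at least `min(epsilon, sqrt(epsilon) * exp(-entropy))`; with `device`, the epsilon tensor now starts out on the right device. A toy sketch (assumed values):

```python
import torch
from transformers import EtaLogitsWarper

warper = EtaLogitsWarper(epsilon=0.1, device="cpu")
scores = torch.log(torch.tensor([[0.75, 0.20, 0.04, 0.01]]))  # toy distribution
filtered = warper(torch.tensor([[0]]), scores)
print(torch.exp(filtered))  # the low-probability tail (0.04, 0.01) drops to 0
```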
@@ -1530,6 +1539,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
The maximum length of the sequence to be generated.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
Examples:
@@ -1553,13 +1564,13 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
```
"""
def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
self.max_length = max_length
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
eos_token_id = torch.tensor(eos_token_id, device=device)
self.eos_token_id = eos_token_id
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
@@ -1568,7 +1579,6 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
self.eos_token_id = self.eos_token_id.to(scores.device)
scores_processed = scores
if cur_len == self.max_length - 1:
scores_processed = torch.full_like(scores, -math.inf)
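Usage sketch (toy values): once the sequence reaches `max_length - 1`, every score except the forced EOS is set to `-inf`.

```python
import torch
from transformers import ForcedEOSTokenLogitsProcessor

proc = ForcedEOSTokenLogitsProcessor(max_length=5, eos_token_id=0, device="cpu")
input_ids = torch.zeros((1, 4), dtype=torch.long)  # cur_len == max_length - 1
scores = torch.ones((1, 20))
out = proc(input_ids, scores)
assert out[0, 0] == 0 and torch.isinf(out[0, 1:]).all()  # only EOS survives
```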
@@ -1770,8 +1780,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
```
"""
def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens))
def __init__(self, begin_suppress_tokens, begin_index, device: str = "cpu"):
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens), device=device)
self.begin_index = begin_index
def set_begin_index(self, begin_index):
@@ -1780,7 +1790,6 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
scores_processed = scores
if input_ids.shape[-1] == self.begin_index:
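A sketch of the begin-suppress behavior (toy values): the listed tokens are blocked only at the first generated position, which is how Whisper keeps EOS from firing immediately.

```python
import torch
from transformers import SuppressTokensAtBeginLogitsProcessor

proc = SuppressTokensAtBeginLogitsProcessor([0, 2], begin_index=3, device="cpu")
scores = torch.zeros((1, 6))
out = proc(torch.zeros((1, 3), dtype=torch.long), scores)  # at begin_index
assert torch.isinf(out[0, [0, 2]]).all()  # suppressed only at this position
```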
@@ -1818,13 +1827,12 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
```
"""
def __init__(self, suppress_tokens):
self.suppress_tokens = torch.tensor(list(suppress_tokens))
def __init__(self, suppress_tokens, device: str = "cpu"):
self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
self.suppress_tokens = self.suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
scores = torch.where(suppress_token_mask, -float("inf"), scores)
return scores
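And the unconditional variant, which masks the listed ids on every step (toy values):

```python
import torch
from transformers import SuppressTokensLogitsProcessor

proc = SuppressTokensLogitsProcessor(suppress_tokens=[1, 3], device="cpu")
out = proc(torch.tensor([[0]]), torch.zeros((1, 6)))
assert torch.isinf(out[0, [1, 3]]).all()  # suppressed on every call
```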
@@ -1915,7 +1923,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
"""
def __init__(
self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None
self,
generate_config,
begin_index: Optional[int] = None,
_detect_timestamp_from_logprob: Optional[bool] = None,
): # support for the kwargs
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
@@ -2292,11 +2303,11 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
Minimum end of speech threshold.
"""
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float, device: str = "cpu"):
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
eos_token_id = torch.tensor(eos_token_id, device=device)
self.eos_token_id = eos_token_id
if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
@@ -2309,7 +2320,6 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores_processed = scores
self.eos_token_id = self.eos_token_id.to(scores.device)
if self.min_eos_p:
probs = torch.nn.functional.softmax(scores.float(), dim=-1)
# create scores full of -inf except for the eos_token_id
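Sketch of Bark's prioritizer (toy values; the class lives in `transformers.generation.logits_process`): once the EOS probability exceeds `min_eos_p`, all non-EOS scores are dropped to `-inf` so EOS wins.

```python
import torch
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor

proc = BarkEosPrioritizerLogitsProcessor(eos_token_id=0, min_eos_p=0.5, device="cpu")
scores = torch.zeros((1, 4))
scores[0, 0] = 10.0  # EOS is overwhelmingly likely here
out = proc(torch.tensor([[0]]), scores)
assert torch.isinf(out[0, 1:]).all()  # everything except EOS is masked
```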
src/transformers/generation/utils.py
@@ -723,6 +723,7 @@ class GenerationMixin:
def _get_logits_warper(
self,
generation_config: GenerationConfig,
device: str,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -765,7 +766,9 @@
)
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
warpers.append(
EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep)
EtaLogitsWarper(
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
)
)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
@@ -818,7 +821,8 @@
):
processors.append(
EncoderRepetitionPenaltyLogitsProcessor(
penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids
penalty=generation_config.encoder_repetition_penalty,
encoder_input_ids=encoder_input_ids,
)
)
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
@@ -830,18 +834,30 @@
and generation_config.encoder_no_repeat_ngram_size > 0
):
processors.append(
EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids)
EncoderNoRepeatNGramLogitsProcessor(
generation_config.encoder_no_repeat_ngram_size,
encoder_input_ids,
)
)
if generation_config.bad_words_ids is not None:
processors.append(
NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)
NoBadWordsLogitsProcessor(
generation_config.bad_words_ids,
generation_config.eos_token_id,
)
)
if (
generation_config.min_length is not None
and generation_config.eos_token_id is not None
and generation_config.min_length > 0
):
processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
processors.append(
MinLengthLogitsProcessor(
generation_config.min_length,
generation_config.eos_token_id,
device=device,
)
)
if (
generation_config.min_new_tokens is not None
and generation_config.eos_token_id is not None
@@ -849,20 +865,32 @@
):
processors.append(
MinNewTokensLengthLogitsProcessor(
input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id
input_ids_seq_length,
generation_config.min_new_tokens,
generation_config.eos_token_id,
device=device,
)
)
if prefix_allowed_tokens_fn is not None:
processors.append(
PrefixConstrainedLogitsProcessor(
prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups
prefix_allowed_tokens_fn,
generation_config.num_beams // generation_config.num_beam_groups,
)
)
if generation_config.forced_bos_token_id is not None:
processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
processors.append(
ForcedBOSTokenLogitsProcessor(
generation_config.forced_bos_token_id,
)
)
if generation_config.forced_eos_token_id is not None:
processors.append(
ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
ForcedEOSTokenLogitsProcessor(
generation_config.max_length,
generation_config.forced_eos_token_id,
device=device,
)
)
if generation_config.remove_invalid_values is True:
processors.append(InfNanRemoveLogitsProcessor())
@@ -875,7 +903,12 @@ class GenerationMixin:
)
)
if generation_config.suppress_tokens is not None:
processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens))
processors.append(
SuppressTokensLogitsProcessor(
generation_config.suppress_tokens,
device=device,
)
)
if generation_config.begin_suppress_tokens is not None:
begin_index = input_ids_seq_length
begin_index = (
@@ -887,7 +920,11 @@
# generation starts after the last token that is forced
begin_index += generation_config.forced_decoder_ids[-1][0]
processors.append(
SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
SuppressTokensAtBeginLogitsProcessor(
generation_config.begin_suppress_tokens,
begin_index,
device=device,
)
)
if generation_config.forced_decoder_ids is not None:
# TODO(Sanchit): deprecate in v4.40 by removing this logic
@@ -1779,7 +1816,12 @@ class GenerationMixin:
# 12. prepare logits warper (if `do_sample` is `True`)
prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None
self._get_logits_warper(
generation_config,
device=input_ids.device,
)
if generation_config.do_sample
else None
)
# 13. run assisted generate
@@ -1812,7 +1854,9 @@ class GenerationMixin:
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper
prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1838,7 +1882,9 @@ class GenerationMixin:
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
# 11. prepare logits warper
prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
# 12. prepare beam search scorer
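A hedged sketch of the call-site change (private API, subject to change; the checkpoint name is just an example): the warper list is now built for a concrete device, so tensors such as `EtaLogitsWarper`'s epsilon start out where the model runs.

```python
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")  # example checkpoint
config = GenerationConfig(do_sample=True, eta_cutoff=0.1)
# After this PR, _get_logits_warper takes the target device:
warpers = model._get_logits_warper(config, device=model.device)
print([type(w).__name__ for w in warpers])
```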
src/transformers/models/musicgen/modeling_musicgen.py
@@ -1729,6 +1729,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
device=input_ids.device,
)
# 10. prepare stopping criteria
@@ -1756,7 +1757,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -2822,6 +2823,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
device=input_ids.device,
)
# 10. prepare stopping criteria
@@ -2849,7 +2851,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -1666,6 +1666,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
device=input_ids.device,
)
# 10. prepare stopping criteria
@@ -1693,7 +1694,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -2681,6 +2682,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
device=input_ids.device,
)
# 10. prepare stopping criteria
@@ -2708,7 +2710,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
src/transformers/models/rag/modeling_rag.py
@@ -1538,6 +1538,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
encoder_input_ids=context_input_ids,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
device=input_ids.device,
)
prepared_stopping_criteria = self._get_stopping_criteria(
src/transformers/models/whisper/generation_whisper.py
@@ -548,13 +548,15 @@ class WhisperGenerationMixin:
self._check_decoder_input_ids(kwargs=kwargs)
# 3. Retrieve logits processors
device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
begin_index = init_tokens.shape[1]
logits_processor = self._retrieve_logit_processors(
generation_config=generation_config,
logits_processor=logits_processor,
begin_index=begin_index, # begin index is index of first generated decoder token
is_shortform=is_shortform,
num_beams=generation_config.num_beams,
num_beams=kwargs.get("num_beams", 1),
device=device,
)
# 5. If we're in shortform mode, simple generate the whole input at once and return the output
@@ -1400,7 +1402,9 @@ class WhisperGenerationMixin:
return max_frames, seek
def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams):
def _retrieve_logit_processors(
self, generation_config, logits_processor, begin_index, is_shortform, num_beams, device
):
if generation_config.return_timestamps is True:
timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
logits_processor = (
@@ -1408,7 +1412,7 @@
)
if generation_config.suppress_tokens is not None:
suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens)
suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
logits_processor = (
[suppress_tokens_processor]
if logits_processor is None
@@ -1418,7 +1422,7 @@
if generation_config.begin_suppress_tokens is not None:
begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
generation_config.begin_suppress_tokens, begin_index=begin_index
generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
)
logits_processor = (
[begin_suppress_processor]
tests/generation/test_logits_process.py
@@ -69,7 +69,7 @@ class LogitsProcessorTest(unittest.TestCase):
batch_size = 4
eos_token_id = 0
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id, device=torch_device)
# check that min length is applied at length 5
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
@@ -91,7 +91,7 @@ class LogitsProcessorTest(unittest.TestCase):
# check that first input is skipped (min new length applying)
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
new_min_dist_processor = MinNewTokensLengthLogitsProcessor(
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id
prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id, device=torch_device
)
expected_eos_scores_before_min_length = batch_size * [-float("inf")]
@@ -450,7 +450,7 @@ class LogitsProcessorTest(unittest.TestCase):
torch.tensor([[0.0, 0.1, 0.8, 0.1], [0.01, 0.04, 0.9, 0.05]], device=torch_device, dtype=torch.float)
)
eta_warp = EtaLogitsWarper(0.0625)
eta_warp = EtaLogitsWarper(0.0625, device=torch_device)
filtered_dist = torch.exp(eta_warp(input_ids, dist))
# dist should be filtered to only keep values with proba >= min(0.0625, sqrt(0.0625) * e^-H(p))
@@ -474,7 +474,7 @@ class LogitsProcessorTest(unittest.TestCase):
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0)
eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0, device=torch_device)
filtered_dist = eta_warp(input_ids, ramp_logits)
# first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
@@ -640,7 +640,7 @@ class LogitsProcessorTest(unittest.TestCase):
scores_comp = scores.clone()
# instantiate all dist processors
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id, device=torch_device)
temp_dist_warp = TemperatureLogitsWarper(temperature=0.5)
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
top_k_warp = TopKLogitsWarper(3)
@@ -767,7 +767,9 @@ class LogitsProcessorTest(unittest.TestCase):
eos_token_id = 0
max_length = 5
logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
logits_processor = ForcedEOSTokenLogitsProcessor(
max_length=max_length, eos_token_id=eos_token_id, device=torch_device
)
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
@@ -927,7 +929,7 @@ class LogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(2, 4)
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p, device=torch_device)
actual_scores = esp(input_ids, scores)
expected_scores_list = [
scores[0].tolist(),
@@ -943,7 +945,7 @@ class LogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(2, 4)
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p, device=torch_device)
actual_scores = esp(input_ids, scores)
expected_scores_list = [
scores[0].tolist(),