"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9129fd0377e4d46cb2d0ea28dc1eb91a15f65b77"
Unverified Commit e2bffcfa authored by Isaac Chung, committed by GitHub

Add early stopping for Bark generation via logits processor (#26675)

* add early stopping logits processor

* black formatted

* indent

* follow method signature

* actual logic

* check for None

* address comments on docstrings and method signature

* add unit test under `LogitsProcessorTest` wip

* unit test passing

* black formatted

* condition per sample

* add to BarkModelIntegrationTests

* wip BarkSemanticModelTest

* rename and add to kwargs handling

* not add to BarkSemanticModelTest

* correct logic and assert last outputs tokens different in test

* doc-builder style

* read from kwargs as well

* assert len of with less than that of without

* ruff

* add back seed and test case

* add original impl default suggestion

* doc-builder

* rename and use softmax

* switch back to LogitsProcessor and update docs wording

* camelCase and spelling and saving compute

* assert strictly less than

* assert less than

* expand test_generate_semantic_early_stop instead
parent 90ee9cea
src/transformers/generation/logits_process.py
@@ -1749,3 +1749,35 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
         unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
         out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
         return out
+
+
+class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
+    r"""This processor ensures that the EOS token is selected if its probability is greater than `min_eos_p`.
+
+    Args:
+        eos_token_id (`Union[int, List[int]]`):
+            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        min_eos_p (`float`, *optional*):
+            Minimum probability threshold for the end-of-speech token: once its probability exceeds this value,
+            all other tokens are masked out so that generation stops early.
+    """
+
+    def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        self.eos_token_id = eos_token_id
+        if min_eos_p is not None and min_eos_p <= 0:
+            raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
+        self.min_eos_p = min_eos_p
+
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if self.min_eos_p:
+            probs = torch.nn.functional.softmax(scores.float(), dim=-1)
+            # create scores full of -inf except for the eos_token_id
+            early_stop_scores = torch.ones_like(scores) * -float("inf")
+            early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
+            # decide per sample; reduce over the EOS axis so the mask broadcasts against the
+            # full vocabulary dimension even when several EOS token ids are configured
+            do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
+            do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True)
+            scores = torch.where(do_early_stop, early_stop_scores, scores)
+        return scores
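To make the masking behavior concrete, here is a minimal, self-contained sketch of the same logic on toy scores (plain PyTorch, no transformers imports; the token id and threshold are made up for a 4-token vocabulary, with 0.2 being the default the docstring attributes to the original implementation):

import torch

min_eos_p = 0.2   # suggested default from the original Bark implementation
eos_token_id = 2  # hypothetical EOS id in a toy 4-token vocabulary

scores = torch.tensor([[0.25, 0.25, -6.0, 0.25],   # EOS very unlikely
                       [0.25, 0.25, 0.25, 0.25]])  # uniform -> EOS prob = 0.25

probs = torch.nn.functional.softmax(scores, dim=-1)

# scores that force EOS: -inf everywhere except the EOS column
early_stop_scores = torch.full_like(scores, -float("inf"))
early_stop_scores[:, eos_token_id] = scores[:, eos_token_id]

# per-sample decision: keep the original scores unless EOS probability clears the threshold
do_early_stop = probs[:, [eos_token_id]] > min_eos_p
out = torch.where(do_early_stop, early_stop_scores, scores)

print(out[0])  # unchanged: EOS prob ~0.0006 < 0.2
print(out[1])  # [-inf, -inf, 0.2500, -inf]: EOS prob 0.25 > 0.2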
src/transformers/models/bark/generation_configuration_bark.py
@@ -44,6 +44,7 @@ class BarkSemanticGenerationConfig(GenerationConfig):
         semantic_vocab_size=10_000,
         max_input_semantic_length=256,
         semantic_rate_hz=49.9,
+        min_eos_p=None,
         **kwargs,
     ):
         """Class that holds a generation configuration for [`BarkSemanticModel`].
@@ -86,6 +87,10 @@ class BarkSemanticGenerationConfig(GenerationConfig):
                 Max length of semantic input vector.
             semantic_rate_hz (`float`, *optional*, defaults to 49.9):
                 Semantic rate in Hertz.
+            min_eos_p (`float`, *optional*):
+                Minimum threshold of the probability of the EOS token for it to be sampled. This is an early stopping
+                strategy to mitigate potential unwanted generations at the end of a prompt. The original implementation
+                suggests a default value of 0.2.
         """
         super().__init__(
             temperature=temperature,
@@ -107,6 +112,7 @@ class BarkSemanticGenerationConfig(GenerationConfig):
         self.semantic_vocab_size = semantic_vocab_size
         self.max_input_semantic_length = max_input_semantic_length
         self.semantic_rate_hz = semantic_rate_hz
+        self.min_eos_p = min_eos_p
 
 
 class BarkCoarseGenerationConfig(GenerationConfig):
...
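With the new field in place, the threshold can be set once on the semantic generation config instead of being passed on every call; a short sketch (module path taken from the transformers source tree, value per the suggested default):

from transformers.models.bark.generation_configuration_bark import BarkSemanticGenerationConfig

# build a semantic generation config that always applies the early-stopping threshold
semantic_generation_config = BarkSemanticGenerationConfig(min_eos_p=0.2)
assert semantic_generation_config.min_eos_p == 0.2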
src/transformers/models/bark/modeling_bark.py
@@ -21,7 +21,11 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from ...generation.logits_process import AlternatingCodebooksLogitsProcessor, SuppressTokensLogitsProcessor
+from ...generation.logits_process import (
+    AlternatingCodebooksLogitsProcessor,
+    BarkEosPrioritizerLogitsProcessor,
+    SuppressTokensLogitsProcessor,
+)
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel, get_parameter_device
 from ...utils import (
@@ -798,12 +802,17 @@ class BarkSemanticModel(BarkCausalModel):
         suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress)
 
+        min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
+        early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor(
+            eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p
+        )
+
         # pass input_ids in order to stay consistent with the transformers generate method even though it is not used
         # (except to get the input seq_len - that's why we keep the first 257 tokens)
         semantic_output = super().generate(
             torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int).to(self.device),
             input_embeds=input_embeds,
-            logits_processor=[suppress_tokens_logits_processor],
+            logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor],
             generation_config=semantic_generation_config,
             **kwargs,
         )  # size: 10048
@@ -1559,7 +1568,8 @@ class BarkModel(BarkPreTrainedModel):
         kwargs_semantic = {
             # if "attention_mask" is set, it should not be passed to CoarseModel and FineModel
-            "attention_mask": kwargs.pop("attention_mask", None)
+            "attention_mask": kwargs.pop("attention_mask", None),
+            "min_eos_p": kwargs.pop("min_eos_p", None),
         }
         kwargs_coarse = {}
         kwargs_fine = {}
...
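End to end, the new `min_eos_p` kwarg is popped in `BarkModel.generate` and forwarded only to the semantic stage. A hedged usage sketch (checkpoint name and processor call follow the standard Bark documentation rather than this diff):

import torch
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")

inputs = processor("Hello, my dog is cute", voice_preset="v2/en_speaker_6")

# 0.2 is the threshold suggested by the original Bark implementation;
# larger values make the semantic stage stop earlier
with torch.no_grad():
    audio = model.generate(**inputs, min_eos_p=0.2)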
tests/generation/test_logits_process.py
@@ -53,6 +53,7 @@ if is_torch_available():
         TypicalLogitsWarper,
         UnbatchedClassifierFreeGuidanceLogitsProcessor,
     )
+    from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
 
 
 @require_torch
@@ -800,3 +801,19 @@ class LogitsProcessorTest(unittest.TestCase):
         self.assertAlmostEqual(out[0].item(), res[0].item())
         self.assertAlmostEqual(out[1].item(), res[1].item())
         self.assertAlmostEqual(out[2].item(), res[2].item())
+
+    def test_early_stop_processor(self):
+        input_ids = None
+        eos_token_id = 2
+        min_eos_p = 0.1  ## some small float
+
+        scores = self._get_uniform_logits(2, 4)
+        scores[0][eos_token_id] = -6  ## less than log(min_eos_p)
+
+        esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
+        actual_scores = esp(input_ids, scores)
+        expected_scores_list = [
+            scores[0].tolist(),
+            [float("-inf"), float("-inf"), scores[0][0], float("-inf")],
+        ]
+        self.assertListEqual(actual_scores.tolist(), expected_scores_list)
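The expected values follow directly from the softmax arithmetic; this small check reproduces them (plain PyTorch, assuming `_get_uniform_logits` fills every position with the same value, here 1/4, as in the surrounding tests):

import torch

scores = torch.full((2, 4), 0.25)  # stand-in for _get_uniform_logits(2, 4)
scores[0][2] = -6

probs = torch.nn.functional.softmax(scores, dim=-1)
print(probs[0][2].item())  # ~0.0006 < 0.1 -> sample 0 keeps its original scores
print(probs[1][2].item())  # 0.25 > 0.1    -> sample 1 is masked to EOS-only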
tests/models/bark/test_modeling_bark.py
@@ -917,7 +917,51 @@ class BarkModelIntegrationTests(unittest.TestCase):
                 temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
             )
+        self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)
+
+    @slow
+    def test_generate_semantic_early_stop(self):
+        input_ids = self.inputs
+        min_eos_p = 0.01
+
+        # fmt: off
+        # check first ids
+        expected_output_ids = [7363, 321, 41, 1461, 6915, 952, 326, 41, 41, 927,]
+        # fmt: on
+
+        # Should be able to read min_eos_p from kwargs
+        with torch.no_grad():
+            torch.manual_seed(0)
+            output_ids_without_min_eos_p = self.model.semantic.generate(
+                **input_ids,
+                do_sample=False,
+                temperature=0.9,
+                semantic_generation_config=self.semantic_generation_config,
+            )
+            torch.manual_seed(0)
+            output_ids_kwargs = self.model.semantic.generate(
+                **input_ids,
+                do_sample=False,
+                temperature=0.9,
+                semantic_generation_config=self.semantic_generation_config,
+                min_eos_p=min_eos_p,
+            )
+        self.assertListEqual(output_ids_without_min_eos_p[0, : len(expected_output_ids)].tolist(), expected_output_ids)
+        self.assertLess(len(output_ids_kwargs[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))
+
+        # Should be able to read min_eos_p from the semantic generation config
+        self.semantic_generation_config.min_eos_p = min_eos_p
+        with torch.no_grad():
+            torch.manual_seed(0)
+            output_ids = self.model.semantic.generate(
+                **input_ids,
+                do_sample=False,
+                temperature=0.9,
+                semantic_generation_config=self.semantic_generation_config,
+            )
+
+        self.assertEqual(output_ids.shape, output_ids_kwargs.shape)
+        self.assertLess(len(output_ids[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))
         self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)
 
     @slow
@@ -1022,26 +1066,30 @@ class BarkModelIntegrationTests(unittest.TestCase):
         input_ids = self.inputs
 
         with torch.no_grad():
+            torch.manual_seed(0)
             self.model.generate(
                 **input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
             )
-            self.model.generate(
+            output_ids_without_min_eos_p = self.model.generate(
                 **input_ids,
-                do_sample=False,
-                temperature=1.0,
+                do_sample=True,
+                temperature=0.9,
                 coarse_do_sample=True,
                 coarse_temperature=0.7,
                 fine_temperature=0.3,
             )
-            self.model.generate(
+            output_ids_with_min_eos_p = self.model.generate(
                 **input_ids,
                 do_sample=True,
-                temperature=0.6,
-                penalty_alpha=0.6,
-                semantic_temperature=0.9,
-                coarse_temperature=0.2,
-                fine_temperature=0.1,
+                temperature=0.9,
+                coarse_temperature=0.7,
+                fine_temperature=0.3,
+                min_eos_p=0.1,
             )
+        self.assertLess(
+            len(output_ids_with_min_eos_p[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist())
+        )
 
     @require_torch_gpu
     @slow
...