Unverified commit ac5d4cf6, authored by Yoach Lacombe, committed by GitHub

Fix Bark batching feature (#27271)

* fix bark batching

* make style

* add tests and make style
parent 8f840edd
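This commit threads a new `return_output_lengths` flag through the whole Bark generation stack so that batched outputs can be trimmed back to their true lengths. A minimal sketch of the batched call path being fixed, assuming the `suno/bark` checkpoint and the `en_speaker_6` voice preset that the integration test below uses (both are illustrative choices, not mandated by the fix):

```python
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark")
model = BarkModel.from_pretrained("suno/bark")

sentences = ["I love HuggingFace", "In the light of the moon, a little egg lay on a leaf"]
inputs = processor(sentences, voice_preset="en_speaker_6")

# With the new flag, generate returns the zero-padded waveforms plus one
# valid length per sample, instead of a single padded tensor with no way
# to tell where each sample ends.
audio, lengths = model.generate(**inputs, return_output_lengths=True)
```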
@@ -909,8 +909,9 @@ class BarkCoarseModel(BarkCausalModel):
         coarse_generation_config: BarkCoarseGenerationConfig = None,
         codebook_size: int = 1024,
         history_prompt: Optional[Dict[str, torch.Tensor]] = None,
+        return_output_lengths: Optional[bool] = None,
         **kwargs,
-    ) -> torch.LongTensor:
+    ) -> Union[torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]]:
         """
         Generates coarse acoustics tokens from input text semantic tokens and an additional optional `Bark` speaker
         prompt.
@@ -926,8 +927,14 @@ class BarkCoarseModel(BarkCausalModel):
                 Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
             history_prompt (`Optional[Dict[str,torch.Tensor]]`, *optional*):
                 Optional `Bark` speaker prompt.
+            return_output_lengths (`bool`, *optional*):
+                Whether or not to return the output lengths. Useful when batching.
         Returns:
-            torch.LongTensor: Output coarse acoustics tokens.
+            By default:
+                torch.LongTensor: Output coarse acoustics tokens.
+            If `return_output_lengths=True`:
+                `Tuple(torch.Tensor, torch.Tensor)`: The output coarse acoustics tokens, and the length of each sample
+                of the batch.
         """
         if semantic_generation_config is None:
@@ -954,13 +961,13 @@ class BarkCoarseModel(BarkCausalModel):
         )
         max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))

-        # beware, depends on the seq_len of the longest sequence of the batch.
-        # Also, the seq_len might be one token too long because of an added
-        # pad_token as compared to Bark original implementation.
-        max_generated_len = np.floor(
-            semantic_output.shape[1] * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks
-        )
-        max_generated_len = int(round(max_generated_len * coarse_generation_config.n_coarse_codebooks))
+        output_lengths = (semantic_output != coarse_generation_config.coarse_semantic_pad_token).sum(1)
+        output_lengths = torch.floor(
+            output_lengths * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks
+        )
+        output_lengths = torch.round(output_lengths * coarse_generation_config.n_coarse_codebooks).int()
+        max_generated_len = torch.max(output_lengths).item()

         batch_size = semantic_output.shape[0]
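The hunk above is the core of the fix: each sample's coarse length is now derived from its own unpadded semantic length rather than from the padded width of the batch. A standalone sketch of that arithmetic, assuming Bark's default rates (`coarse_rate_hz=75`, `semantic_rate_hz=49.9`, `n_coarse_codebooks=2`); the concrete numbers are illustrative only:

```python
import torch

# semantic_to_coarse_ratio = coarse_rate_hz * n_coarse_codebooks / semantic_rate_hz
semantic_to_coarse_ratio = 75 * 2 / 49.9  # ~3.006 coarse tokens per semantic token
n_coarse_codebooks = 2

# two samples with 100 and 160 non-pad semantic tokens in the same batch
unpadded = torch.tensor([100, 160])

# floor to whole coarse steps, then scale back up by the codebook count so the
# length is a multiple of n_coarse_codebooks
lengths = torch.floor(unpadded * semantic_to_coarse_ratio / n_coarse_codebooks)
lengths = torch.round(lengths * n_coarse_codebooks).int()
print(lengths)               # tensor([300, 480], dtype=torch.int32)
print(lengths.max().item())  # 480 -> max_generated_len for the whole batch
```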
@@ -1026,6 +1033,9 @@ class BarkCoarseModel(BarkCausalModel):
         coarse_output = x_coarse[:, len_coarse_history:]

+        if return_output_lengths:
+            return coarse_output, output_lengths
+
         return coarse_output
@@ -1502,13 +1512,21 @@ class BarkModel(BarkPreTrainedModel):
             # We'll offload the last model manually.
             self.codec_model_hook = hook

-    def codec_decode(self, fine_output):
+    def codec_decode(self, fine_output, output_lengths=None):
         """Turn quantized audio codes into audio array using encodec."""
         fine_output = fine_output.transpose(0, 1)
         emb = self.codec_model.quantizer.decode(fine_output)
-        out = self.codec_model.decoder(emb)
-        audio_arr = out.squeeze(1)  # squeeze the codebook dimension
+
+        if output_lengths is not None:
+            # encodec uses LSTMs, which behave differently with appended padding.
+            # Decoding with encodec takes around 0.1% of the total generation time,
+            # so to keep generation quality we break batching here.
+            out = [sample[:, :l].unsqueeze(0) for (sample, l) in zip(emb, output_lengths)]
+            audio_arr = [self.codec_model.decoder(sample).squeeze() for sample in out]
+        else:
+            out = self.codec_model.decoder(emb)
+            audio_arr = out.squeeze(1)  # squeeze the codebook dimension

         return audio_arr
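Because appended padding can change what the encodec decoder produces even for the valid frames, the hunk above crops each sample to its true length and decodes it on its own. A minimal sketch of that trim-then-decode pattern, with a stand-in `Conv1d` where the real code calls `self.codec_model.decoder`; all shapes here are hypothetical:

```python
import torch

decoder = torch.nn.Conv1d(8, 1, kernel_size=1)  # stand-in for the encodec decoder
emb = torch.randn(2, 8, 10)                     # (batch, channels, frames), padded to 10 frames
output_lengths = torch.tensor([6, 10])          # valid frames per sample

# crop each sample before it reaches the decoder, trading batching for quality
audio = [decoder(s[:, :l].unsqueeze(0)).squeeze() for s, l in zip(emb, output_lengths)]
print([a.shape for a in audio])  # [torch.Size([6]), torch.Size([10])]
```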
@@ -1517,6 +1535,7 @@ class BarkModel(BarkPreTrainedModel):
         self,
         input_ids: Optional[torch.Tensor] = None,
         history_prompt: Optional[Dict[str, torch.Tensor]] = None,
+        return_output_lengths: Optional[bool] = None,
         **kwargs,
     ) -> torch.LongTensor:
         """
@@ -1535,9 +1554,15 @@ class BarkModel(BarkPreTrainedModel):
                 semantic, coarse and fine respectively. It has the priority over the keywords without a prefix.
                 This means you can, for example, specify a generation strategy for all sub-models except one.
+            return_output_lengths (`bool`, *optional*):
+                Whether or not to return the waveform lengths. Useful when batching.
         Returns:
-            torch.LongTensor: Output generated audio.
+            By default:
+                - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
+            When `return_output_lengths=True`:
+                Returns a tuple made of:
+                - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
+                - **output_lengths** (`torch.Tensor` of shape (batch_size)): The length of each waveform in the batch.
         Example:

         ```python
@@ -1603,9 +1628,16 @@ class BarkModel(BarkPreTrainedModel):
             semantic_generation_config=semantic_generation_config,
             coarse_generation_config=coarse_generation_config,
             codebook_size=self.generation_config.codebook_size,
+            return_output_lengths=return_output_lengths,
             **kwargs_coarse,
         )

+        output_lengths = None
+        if return_output_lengths:
+            coarse_output, output_lengths = coarse_output
+            # (batch_size, seq_len*coarse_codebooks) -> (batch_size, seq_len)
+            output_lengths = output_lengths // coarse_generation_config.n_coarse_codebooks
+
         # 3. "generate" from the fine model
         output = self.fine_acoustics.generate(
             coarse_output,
@@ -1625,10 +1657,15 @@ class BarkModel(BarkPreTrainedModel):
             self.codec_model = self.codec_model.to(self.device)

         # 4. Decode the output and generate audio array
-        audio = self.codec_decode(output)
+        audio = self.codec_decode(output, output_lengths)

         if getattr(self, "codec_model_hook", None) is not None:
             # Offload codec_model to CPU
             self.codec_model_hook.offload()

+        if return_output_lengths:
+            output_lengths = [len(sample) for sample in audio]
+            audio = nn.utils.rnn.pad_sequence(audio, batch_first=True, padding_value=0)
+            return audio, output_lengths
+
         return audio
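Since `generate` now pads the per-sample waveforms back into a single tensor with `pad_sequence`, a caller can recover the individual clips from the returned lengths. A hypothetical caller-side sketch, with `model` and `inputs` as in the example at the top of this page:

```python
# the returned lengths count raw waveform samples, so slicing undoes the padding
audio, output_lengths = model.generate(**inputs, return_output_lengths=True)
waveforms = [audio[i, :length] for i, length in enumerate(output_lengths)]
```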
@@ -1067,6 +1067,37 @@ class BarkModelIntegrationTests(unittest.TestCase):
             self.model.generate(**input_ids, do_sample=True, temperature=0.6, penalty_alpha=0.6)
             self.model.generate(**input_ids, do_sample=True, temperature=0.6, num_beams=4)

+    @slow
+    def test_generate_batching(self):
+        args = {"do_sample": False, "temperature": None}
+
+        s1 = "I love HuggingFace"
+        s2 = "In the light of the moon, a little egg lay on a leaf"
+        voice_preset = "en_speaker_6"
+        input_ids = self.processor([s1, s2], voice_preset=voice_preset).to(torch_device)
+
+        # generate in batch
+        outputs, audio_lengths = self.model.generate(**input_ids, **args, return_output_lengths=True)
+
+        # generate one-by-one
+        s1 = self.processor(s1, voice_preset=voice_preset).to(torch_device)
+        s2 = self.processor(s2, voice_preset=voice_preset).to(torch_device)
+        output1 = self.model.generate(**s1, **args)
+        output2 = self.model.generate(**s2, **args)
+
+        # up until the coarse acoustic model (included), results are the same;
+        # the fine acoustic model introduces small differences.
+        # First check that the lengths match (they should, since length is decided by the coarse model).
+        self.assertEqual(tuple(audio_lengths), (output1.shape[1], output2.shape[1]))
+
+        # then assert the waveforms are almost equal
+        self.assertTrue(torch.allclose(outputs[0, : audio_lengths[0]], output1.squeeze(), atol=2e-3))
+        self.assertTrue(torch.allclose(outputs[1, : audio_lengths[1]], output2.squeeze(), atol=2e-3))
+
+        # now test a single input with return_output_lengths=True
+        outputs, _ = self.model.generate(**s1, **args, return_output_lengths=True)
+        self.assertTrue((outputs == output1).all().item())
+
     @slow
     def test_generate_end_to_end_with_sub_models_args(self):
         input_ids = self.inputs
...