# Inference tips

Parler-TTS benefits from a number of optimizations that can make the model up to 4x faster. Combined with the ability to stream audio as it is being generated, this can bring time-to-first-audio under 500 ms on a modern GPU.

## 📖 Quick Index
* [Efficient Attention Implementations](#efficient-attention-implementations)
* [Compilation](#compilation)
* [Streaming](#streaming)
* [Batch generation](#batch-generation)

## Efficient Attention implementations

Parler-TTS supports [SDPA](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) and [Flash Attention 2](https://github.com/Dao-AILab/flash-attention).  

SDPA is used by default and speeds up generation by up to 1.4x compared with eager attention.

To switch between attention implementations, pass the `attn_implementation` argument when loading the checkpoint:

```py
import torch
from parler_tts import ParlerTTSForConditionalGeneration

torch_device = "cuda:0" # use "mps" for Mac
torch_dtype = torch.bfloat16
model_name = "parler-tts/parler-tts-mini-v1"

attn_implementation = "eager" # or "sdpa" (default) or "flash_attention_2"

model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation
).to(torch_device, dtype=torch_dtype)
```
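
Flash Attention 2 requires the separate `flash-attn` package and, typically, a compatible GPU with half-precision weights. If you are not sure it is installed, a small helper can choose the implementation at load time. This is a minimal sketch: `pick_attn_implementation` is a hypothetical helper, not part of Parler-TTS, and it only checks that the `flash_attn` module is importable, not that your GPU supports it.

```python
import importlib.util


def pick_attn_implementation(prefer_flash=True):
    """Return an `attn_implementation` string for `from_pretrained`.

    Hypothetical helper: picks "flash_attention_2" only if the `flash_attn`
    package can be found, otherwise falls back to the default "sdpa".
    """
    if prefer_flash and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


attn_implementation = pick_attn_implementation()
print(attn_implementation)
```

The result can then be passed as `attn_implementation=...` to `ParlerTTSForConditionalGeneration.from_pretrained` exactly as in the example above.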

## Compilation

[Compiling](https://pytorch.org/docs/stable/generated/torch.compile.html) the forward method of Parler can speed up generation by up to 4.5x.

As an indication, `mode="default"` brings a speed-up of 1.4x compared to no compilation, while `mode="reduce-overhead"` gives much faster generation, at the cost of a longer compilation time and the need to generate twice before the benefits of compilation show up.

```py
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

torch_device = "cuda:0"
torch_dtype = torch.bfloat16
model_name = "parler-tts/parler-tts-mini-v1"

# need to set padding max length
max_length = 50

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name) 
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation="eager"
).to(torch_device, dtype=torch_dtype)

# compile the forward pass
compile_mode = "default" # choose "reduce-overhead" for a 3 to 4x speed-up
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode=compile_mode)

# warmup
inputs = tokenizer("This is for compilation", return_tensors="pt", padding="max_length", max_length=max_length).to(torch_device)

model_kwargs = {**inputs, "prompt_input_ids": inputs.input_ids, "prompt_attention_mask": inputs.attention_mask, }

n_steps = 1 if compile_mode == "default" else 2
for _ in range(n_steps):
    _ = model.generate(**model_kwargs)


# now you can benefit from compilation speed-ups
...

```
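
To measure the compilation speed-up on your own hardware, time `generate` only after the warm-up calls, since the first call(s) include compilation. The sketch below uses a hypothetical `benchmark` helper (not part of Parler-TTS or PyTorch) and a cheap stand-in function so it runs anywhere. Because CUDA kernels run asynchronously, on a GPU you may want the timed function to call `torch.cuda.synchronize()` before returning so that queued work is included.

```python
import time
from statistics import median


def benchmark(fn, n_warmup=1, n_runs=3):
    """Run `fn` n_warmup times untimed (e.g. to trigger compilation),
    then return the median wall-clock time in seconds over n_runs calls."""
    for _ in range(n_warmup):
        fn()
    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return median(timings)


# Demo with a cheap CPU-bound function standing in for `model.generate`.
calls = []


def work():
    calls.append(1)  # count how many times the function is invoked
    return sum(range(100_000))


elapsed = benchmark(work, n_warmup=2, n_runs=3)
print(f"median: {elapsed:.6f} s over {len(calls)} calls")  # 5 calls in total
```

With the compilation example above, this could look like `benchmark(lambda: model.generate(**model_kwargs), n_warmup=n_steps)`; running the same measurement on an uncompiled model gives the speed-up for your setup.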


## Streaming

### How Does It Work?

Parler-TTS is an auto-regressive transformer-based model, meaning it generates audio codes (tokens) in a causal fashion.

At each decoding step, the model generates a new set of audio codes, conditioned on the text input and all previous audio codes. Given the frame rate of the [DAC model](https://huggingface.co/parler-tts/dac_44khZ_8kbps) used to decode the generated codes into an audio waveform, each set of generated audio codes corresponds to roughly 0.0116 seconds (1/86 s). This means a total of 1720 decoding steps is required to generate 20 seconds of audio.
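
The step counts follow directly from the codec's frame rate. A quick check, assuming a rounded frame rate of 86 frames per second (the figure used in the streaming example below, where 86 steps make one second of audio):

```python
# Assumption: the DAC codec emits about 86 frames of audio codes per second of audio
# (44.1 kHz sampling rate / an assumed hop length of 512 samples ≈ 86.13, rounded to 86).
frame_rate = 86  # decoding steps per second of audio

seconds_per_step = 1 / frame_rate  # seconds of audio produced by one decoding step
total_steps = frame_rate * 20      # decoding steps needed for 20 seconds of audio

print(f"{seconds_per_step:.4f} s per step")  # 0.0116 s per step
print(total_steps)                           # 1720
```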

Rather than waiting for the entire audio sequence to be generated, which would require the full 1720 decoding steps, we can start playing the audio after a specified number of decoding steps have been reached, a technique known as [*streaming*](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming).
For example, after 86 steps we have the first second of audio ready, and can play it without waiting for the remaining decoding steps to complete. As we continue to generate with the Parler-TTS model, we append new chunks of generated audio to our output waveform on the fly. After the full 1720 decoding steps, the generated audio is complete and is composed of 20 chunks of audio, each corresponding to 86 tokens.
This method of playing incremental generations reduces the latency of the Parler-TTS model from the total time to generate 1720 tokens to the time taken to generate the first chunk of audio (86 tokens). This can significantly improve perceived latency, particularly when the chunk size is chosen to be small. In practice, the chunk size should be tuned to your device: a smaller chunk size means the first chunk is ready sooner, but it should not be so small that generating each chunk takes longer than playing it back.
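
The chunking bookkeeping can be illustrated without the model. The sketch below uses hypothetical `stream_chunks` and `keeps_up` helpers and the rounded 86 frames per second from the text: it groups a stream of decoding steps into chunks of `play_steps` and checks the real-time condition described above. It covers the bookkeeping only; the library's `ParlerTTSStreamer` also decodes the codes into audio, which is left out here.

```python
def stream_chunks(codes, play_steps):
    """Yield consecutive chunks of `play_steps` codes as soon as each one is complete."""
    buffer = []
    for code in codes:
        buffer.append(code)
        if len(buffer) == play_steps:
            yield buffer
            buffer = []
    if buffer:  # flush the final, possibly shorter, chunk
        yield buffer


def keeps_up(gen_time_per_chunk_s, play_steps, frame_rate=86):
    """True if a chunk is generated faster than it takes to play it back."""
    return gen_time_per_chunk_s < play_steps / frame_rate


frame_rate = 86                     # assumed decoding steps per second of audio
play_steps = int(frame_rate * 1.0)  # 1-second chunks, as in the example above
chunks = list(stream_chunks(range(1720), play_steps))
print(len(chunks), len(chunks[0]))  # 20 86
```

With `play_steps = 86`, generation keeps up with playback as long as each chunk takes less than one second to produce.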


### How Can I Use It?

We've added [ParlerTTSStreamer](https://github.com/huggingface/parler-tts/blob/main/parler_tts/streamer.py) to the library. Don't hesitate to adapt it to your use-case.
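
Streaming relies on a producer/consumer pattern: `model.generate` runs in a background thread and hands results to the streamer, while the caller iterates over the streamer in the main thread. The sketch below shows that pattern with a thread-safe queue and an end-of-stream sentinel, in the style of `transformers`' iterator streamers. It is a simplified illustration, not the library's implementation: `ChunkQueueStreamer` and `fake_generate` are hypothetical names, and in the real streamer `put` receives newly generated codes and decoding to audio happens inside the streamer, whereas here the producer hands over ready-made chunks.

```python
import queue
import threading


class ChunkQueueStreamer:
    """Minimal streamer: a worker thread puts chunks, the caller iterates over them."""

    _STOP = object()  # sentinel marking the end of the stream

    def __init__(self, timeout=None):
        self._queue = queue.Queue()
        self._timeout = timeout  # seconds to wait for each chunk (None = forever)

    def put(self, chunk):
        self._queue.put(chunk)

    def end(self):
        self._queue.put(self._STOP)

    def __iter__(self):
        return self

    def __next__(self):
        item = self._queue.get(timeout=self._timeout)
        if item is self._STOP:
            raise StopIteration
        return item


def fake_generate(streamer, n_chunks=3):
    """Stand-in for model.generate: produces chunks, then signals the end."""
    for i in range(n_chunks):
        streamer.put([i] * 4)
    streamer.end()


streamer = ChunkQueueStreamer(timeout=5.0)
thread = threading.Thread(target=fake_generate, kwargs={"streamer": streamer})
thread.start()
received = [chunk for chunk in streamer]  # consumed while the producer runs
thread.join()
print(received)  # [[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]
```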

Here's how to create a generator out of the streamer.

```py
import torch
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
from transformers import AutoTokenizer
from threading import Thread

torch_device = "cuda:0" # Use "mps" for Mac 
torch_dtype = torch.bfloat16
model_name = "parler-tts/parler-tts-mini-v1"

# need to set padding max length
max_length = 50

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name) 
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
).to(torch_device, dtype=torch_dtype)

sampling_rate = model.audio_encoder.config.sampling_rate
frame_rate = model.audio_encoder.config.frame_rate

def generate(text, description, play_steps_in_s=0.5):
  play_steps = int(frame_rate * play_steps_in_s)
  streamer = ParlerTTSStreamer(model, device=torch_device, play_steps=play_steps)
  # tokenization
  inputs = tokenizer(description, return_tensors="pt").to(torch_device)
  prompt = tokenizer(text, return_tensors="pt").to(torch_device)
  # create generation kwargs
  generation_kwargs = dict(
    input_ids=inputs.input_ids,
    prompt_input_ids=prompt.input_ids,
    attention_mask=inputs.attention_mask,
    prompt_attention_mask=prompt.attention_mask,
    streamer=streamer,
    do_sample=True,
    temperature=1.0,
    min_new_tokens=10,
  )
  # initialize Thread
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
  thread.start()
  # iterate over chunks of audio
  for new_audio in streamer:
    if new_audio.shape[0] == 0:
      break
    print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 4)} seconds")
    yield sampling_rate, new_audio


# now you can do
text = "This is a test of the streamer class"
description = "Jon's talking really fast."

chunk_size_in_s = 0.5

for (sampling_rate, audio_chunk) in generate(text, description, chunk_size_in_s):
  # You can do everything that you need with the chunk now
  # For example: stream it, save it, play it.
  print(audio_chunk.shape) 
```
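
The loop above only prints each chunk's shape. If you also want to keep the full utterance, the chunks can be concatenated and written to disk once the stream ends. This sketch assumes each chunk is a 1-D sequence of floats in [-1, 1] (the `shape[0]` check in `generate` suggests 1-D arrays); `save_chunks_as_wav` is a hypothetical helper built on the standard-library `wave` module, not part of Parler-TTS.

```python
import array
import math
import os
import sys
import tempfile
import wave


def save_chunks_as_wav(chunks, sampling_rate, path):
    """Concatenate float audio chunks (values in [-1, 1]) into one 16-bit mono WAV file.

    Returns the number of samples written.
    """
    samples = array.array("h")  # signed 16-bit integers
    for chunk in chunks:
        for x in chunk:
            x = max(-1.0, min(1.0, float(x)))  # clip to the valid range
            samples.append(int(round(x * 32767)))
    if sys.byteorder == "big":
        samples.byteswap()  # WAV PCM data is little-endian
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes = 16 bits per sample
        wav_file.setframerate(sampling_rate)
        wav_file.writeframes(samples.tobytes())
    return len(samples)


# Demo with two synthetic 440 Hz chunks standing in for the streamer's output.
demo_rate = 16000
demo_chunks = [
    [math.sin(2 * math.pi * 440 * t / demo_rate) for t in range(800)],
    [math.sin(2 * math.pi * 440 * t / demo_rate) for t in range(800, 1600)],
]
demo_path = os.path.join(tempfile.gettempdir(), "streamed_demo.wav")
n_written = save_chunks_as_wav(demo_chunks, demo_rate, demo_path)
print(n_written)  # 1600
```

With the `generate` generator above, this could be called as `save_chunks_as_wav((chunk for _, chunk in generate(text, description)), sampling_rate, "streamed.wav")`, which writes the file once the stream is exhausted; for live playback you would instead hand each chunk to your audio output as it arrives.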

## Batch generation

Batching combines the generation of multiple samples into a single call, so that the total time spent is lower than generating the samples one by one.

Here is a quick example of how you can use it:

```py
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
import scipy.io.wavfile


repo_id = "parler-tts/parler-tts-mini-v1"

model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(repo_id, padding_side="left")
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)

input_text = ["Hey, how are you doing?", "I'm not sure how to feel about it."]
description = 2 * ["A male speaker with a monotone and high-pitched voice is delivering his speech at a really low speed in a confined environment."]

inputs = tokenizer(description, return_tensors="pt", padding=True).to("cuda")
prompt = tokenizer(input_text, return_tensors="pt", padding=True).to("cuda")

set_seed(0)
generation = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    prompt_input_ids=prompt.input_ids,
    prompt_attention_mask=prompt.attention_mask,
    do_sample=True,
    return_dict_in_generate=True,
)

audio_1 = generation.sequences[0, :generation.audios_length[0]]
audio_2 = generation.sequences[1, :generation.audios_length[1]]

print(audio_1.shape, audio_2.shape)
scipy.io.wavfile.write("sample_out.wav", rate=feature_extractor.sampling_rate, data=audio_1.cpu().numpy().squeeze())
scipy.io.wavfile.write("sample_out_2.wav", rate=feature_extractor.sampling_rate, data=audio_2.cpu().numpy().squeeze())
```
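
The two `audio_*` lines above cut each padded row of `generation.sequences` down to its entry in `generation.audios_length`. For larger batches the same idea fits in a loop. This is a minimal sketch with a hypothetical `trim_batch` helper; the demo uses plain lists so it runs without a GPU, and the same slicing also works on a 2-D tensor.

```python
def trim_batch(sequences, lengths):
    """Cut each padded row of `sequences` down to its true length.

    Works for nested Python lists and for 2-D tensors/arrays, since both
    support `sequences[i][:n]` slicing; `int()` also accepts 0-d tensors.
    """
    return [sequences[i][: int(n)] for i, n in enumerate(lengths)]


padded = [[0.1, 0.2, 0.3, 0.0, 0.0], [0.4, 0.5, 0.6, 0.7, 0.8]]
trimmed = trim_batch(padded, [3, 5])
print([len(x) for x in trimmed])  # [3, 5]
```

In the batch example, `audios = trim_batch(generation.sequences, generation.audios_length)` would give one trimmed waveform per description/prompt pair, each of which can be written with `scipy.io.wavfile.write` as above.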