Replace consecutive `num_spk_embs` speaker embedding placeholders in `input_embeds` with pre-prepared speaker embeddings. The replacement is done in place; no new tensor is created, so no value is returned.
Args:
input_ids (torch.Tensor): Input ID tensor, shape [batch_size, seq_len_max]
audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq).
past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for the attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head).
finished (bool): Boolean indicating whether generation is complete.
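A minimal sketch of the in-place replacement described above, assuming a hypothetical `spk_emb_token_id` placeholder token and illustrative shapes:
```python
import torch

batch_size, seq_len_max, hidden_dim = 1, 16, 768
num_spk_embs, spk_emb_token_id = 1, 101  # hypothetical placeholder token id

input_ids = torch.full((batch_size, seq_len_max), 0)
input_ids[0, 2 : 2 + num_spk_embs] = spk_emb_token_id  # consecutive placeholders
input_embeds = torch.randn(batch_size, seq_len_max, hidden_dim)
spk_embs = torch.randn(batch_size, num_spk_embs, hidden_dim)  # pre-prepared speaker embeddings

# Overwrite the placeholder rows in place; `input_embeds` is modified, nothing is returned.
mask = input_ids == spk_emb_token_id  # [batch_size, seq_len_max]
input_embeds[mask] = spk_embs.reshape(-1, hidden_dim).to(input_embeds.dtype)
```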
"""A conditional text-to-speech model that can generate speech from text with speaker conditioning.
This model extends PreTrainedModel to provide text-to-speech capabilities with:
- LLM hidden state conditioning
- Streaming generation
The model uses a transformer architecture with LLM hidden states and can operate in both
streaming and non-streaming modes for flexible deployment.
The model processes sequences in the following format:
| text bos token | LLM embedding projected to tts embedding space | text tokens (fixed length, reserved for future tokens) | audio bos token | audio tokens (audio token length is not fixed) | audio eos token |
The format is designed to support LLM-conditioned streaming audio generation.
Usage:
To support streaming generation, two global variables should be maintained outside of the model.
1. `audio_input_ids`: stores *discrete* audio codes. It is a tensor with shape [1, sequence length+1, num_vq].
2. `past_key_values`: stores the KV cache for both text tokens and audio codes. It is a list of tuples; each tuple contains two tensors with shape [1, num_attention_heads, sequence length, hidden_size // num_attention_heads],
where `num_vq` is the number of audio codebooks; in the default setting it is `4`.
1. Create an empty `past_key_values` with cache length covering the `bos` token, the speaker embeddings, and the reserved text region (a full initialization is shown in the sketch after this list):
```python
initial_kv_cache_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len # where `1` denotes the `bos` token
```
2. At the same time, create an empty `audio_input_ids` with shape [1, sequence length, num_vq], where `num_vq` denotes the number of audio codebook layers. Text tokens are also included in this sequence, but they will be zeros and will not be used; they are just placeholders. Also maintain a `streaming_tts_text_mask` and, after each text prefill, mark the prefilled positions:
```python
streaming_tts_text_mask[0:end] = 1 # denotes that positions `0:end` hold prefilled text tokens
```
3. Generate audio codes using the `generate` method.
```python
outputs = model.generate(
input_ids=audio_input_ids,
past_key_values=past_key_values,
streaming_tts_text_mask=streaming_tts_text_mask,
max_new_token=50,
)
# update past_key_values and input_ids
past_key_values = outputs.past_key_values
audio_input_ids = outputs.input_ids
```
The `past_key_values` is extended by `max_new_token=50`, and `audio_input_ids` is also extended by `max_new_token=50` after the `generate` call.
4. Notice that after prefilling `10` text tokens, the model can generate up to `50` audio tokens. If you want to generate more audio tokens, you need to prefill the next `10` text tokens. It is also okay to generate only `25` audio tokens for a faster initial response.
5. Repeat steps `3` and `4` (prefill more text, then generate more audio) as needed in your streaming audio generation cases, but ensure usage complies with the guidelines discussed above.
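Putting these steps together, here is a minimal end-to-end sketch of the streaming loop. The config values (`num_layers`, `num_heads`, `head_dim`), the zero-preallocated "empty" cache layout, the text position offset (past `bos` + speaker embeddings, per the sequence format above), and the `next_text_chunk()` helper are illustrative assumptions; only the documented `prefill_text` / `generate` arguments are used.
```python
import torch

# Illustrative config values; in practice read them from `model.config`.
num_layers, num_heads, head_dim = 20, 16, 64

# Step 1: "empty" KV cache preallocated with zeros (assumed layout),
# covering `bos` + speaker embeddings + reserved text region.
initial_kv_cache_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len
past_key_values = [
    (
        torch.zeros(1, num_heads, initial_kv_cache_length, head_dim),
        torch.zeros(1, num_heads, initial_kv_cache_length, head_dim),
    )
    for _ in range(num_layers)
]

# Step 2: placeholder audio ids (text positions stay zero) and an all-zero text mask.
audio_input_ids = torch.zeros(1, initial_kv_cache_length + 1, model.num_vq, dtype=torch.long)
streaming_tts_text_mask = torch.zeros(model.streaming_text_reserved_len, dtype=torch.long)

num_prefilled = 0
while True:
    # Prefill the next chunk of text tokens, e.g. 10 at a time.
    chunk_ids = next_text_chunk()  # hypothetical helper, returns shape [1, chunk_len]
    begin = 1 + model.num_spk_embs + num_prefilled  # assumed offset past `bos` + speaker embeddings
    position_ids = torch.arange(begin, begin + chunk_ids.shape[1]).unsqueeze(0)
    model.prefill_text(
        input_ids=chunk_ids,
        position_ids=position_ids,
        past_key_values=past_key_values,
    )
    streaming_tts_text_mask[num_prefilled : num_prefilled + chunk_ids.shape[1]] = 1
    num_prefilled += chunk_ids.shape[1]

    # Generate up to 50 audio tokens for the text prefilled so far.
    outputs = model.generate(
        input_ids=audio_input_ids,
        past_key_values=past_key_values,
        streaming_tts_text_mask=streaming_tts_text_mask,
        max_new_token=50,
    )
    past_key_values = outputs.past_key_values
    audio_input_ids = outputs.input_ids
    if outputs.finished:  # the model emitted the audio eos token
        break
```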
"""Prefill a chunk of new text tokens in streaming setting.
Specifically speaking, update `past_key_values` using new text tokens, then the model will read the new text tokens.
Args:
input_ids (Tensor): Tensor of shape [batch_size, seq_len]
position_ids (LongTensor): Tensor of shape [batch_size, seq_len]
past_key_values (List[Tuple[Tensor]]): KV cache of all layers; each layer is a tuple (Tensor, Tensor) holding keys and values. Each tensor has sequence length `self.streaming_text_reserved_len`. `past_key_values` will be updated.
lm_spk_emb_last_hidden_states (Tensor, optional): Tensor of shape [batch_size, num_spk_emb, llm_dim]. Defaults to None.
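For speaker conditioning, the optional LLM speaker-embedding hidden states can be passed along with the first prefilled chunk. A minimal sketch, assuming `spk_hidden` was extracted from the LLM beforehand:
```python
# Sketch: the first text chunk also carries the speaker condition.
model.prefill_text(
    input_ids=chunk_ids,              # [batch_size, seq_len]
    position_ids=position_ids,        # [batch_size, seq_len]
    past_key_values=past_key_values,  # updated by the call
    lm_spk_emb_last_hidden_states=spk_hidden,  # [batch_size, num_spk_emb, llm_dim], assumed precomputed
)
```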
"""Generate audio codes in streaming setting or non-streaming setting.
Specifically speaking, generate audio codes when not all text tokens are prefilled.
Always pass a valid `past_key_values` to the method. The method does not do `prefill` by itself. It relies on `prefill_text` method to provide valid `past_key_values`. Please refer to docstring of this class for more details.
In this method, we borrowed a lot of codes from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/gpt.py`.
Args:
input_ids (torch.Tensor): Input token ids.
past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism.
temperature (torch.Tensor): Temperature for sampling.
eos_token (Union[int, torch.Tensor]): End of sequence token.
streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text. Defaults to None.
max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50.
logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to [].
logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of logits processors. Defaults to [].
show_tqdm (bool, optional): Whether to show progress bar. Defaults to True.
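As an illustration, a typical sampling setup for `generate` might look like the following. The warpers come from `transformers`; `eos_token_id` and the per-codebook temperature shape are assumptions, not confirmed by this docstring.
```python
import torch
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper

eos_token_id = 4096  # hypothetical audio eos id; use the model's real value
logits_warpers = [TopPLogitsWarper(top_p=0.7), TopKLogitsWarper(top_k=20)]

outputs = model.generate(
    input_ids=audio_input_ids,
    past_key_values=past_key_values,
    temperature=torch.tensor([0.8] * model.num_vq),  # assumed one temperature per codebook
    eos_token=eos_token_id,
    streaming_tts_text_mask=streaming_tts_text_mask,
    max_new_token=50,
    logits_warpers=logits_warpers,
)
```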