Commit 91542bfa authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

remove stable speech mentions

parent 315d3fca
# Stable Speech
# Parler-TTS
Work in-progress reproduction of the text-to-speech (TTS) model from the paper [Natural language guidance of high-fidelity text-to-speech with synthetic annotations](https://www.text-description-to-speech.com)
by Dan Lyth and Simon King, from Stability AI and Edinburgh University respectively.
......
from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import T5Config, EncodecConfig
from transformers import AutoConfig
......@@ -12,7 +12,7 @@ encodec = AutoConfig.from_pretrained(encodec_version)
encodec_vocab_size = encodec.codebook_size
decoder_config = StableSpeechDecoderConfig(
decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size+1,
max_position_embeddings=2048,
num_hidden_layers=4,
......@@ -34,13 +34,13 @@ decoder_config = StableSpeechDecoderConfig(
decoder = StableSpeechForCausalLM(decoder_config)
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version,
decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
......
from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
from transformers import AutoModel
from transformers import AutoConfig, AutoModel
from stable_speech import DACConfig, DACModel
from parler_tts import DACConfig, DACModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
......@@ -19,7 +19,7 @@ encodec = AutoConfig.from_pretrained(encodec_version)
encodec_vocab_size = encodec.codebook_size
decoder_config = StableSpeechDecoderConfig(
decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size+1,
max_position_embeddings=2048,
num_hidden_layers=4,
......@@ -41,13 +41,13 @@ decoder_config = StableSpeechDecoderConfig(
decoder = StableSpeechForCausalLM(decoder_config)
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version,
decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
......
from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import T5Config, EncodecConfig
from transformers import AutoConfig
from transformers import AutoConfig, AutoModel
from stable_speech import DACConfig, DACModel
from parler_tts import DACConfig, DACModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
......@@ -20,7 +20,7 @@ encodec = AutoConfig.from_pretrained(encodec_version)
encodec_vocab_size = encodec.codebook_size
decoder_config = StableSpeechDecoderConfig(
decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size+1,
max_position_embeddings=3000, # 30 s = 2580
num_hidden_layers=12,
......@@ -40,13 +40,13 @@ decoder_config = StableSpeechDecoderConfig(
)
decoder = StableSpeechForCausalLM(decoder_config)
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version,
decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
......
from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import T5Config, EncodecConfig
from transformers import AutoConfig
from transformers import AutoConfig, AutoModel
from stable_speech import DACConfig, DACModel
from parler_tts import DACConfig, DACModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
......@@ -20,7 +20,7 @@ encodec = AutoConfig.from_pretrained(encodec_version)
encodec_vocab_size = encodec.codebook_size
decoder_config = StableSpeechDecoderConfig(
decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size+1,
max_position_embeddings=4096, # 30 s = 2580
num_hidden_layers=8,
......@@ -40,13 +40,13 @@ decoder_config = StableSpeechDecoderConfig(
)
decoder = StableSpeechForCausalLM(decoder_config)
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder_small/")
model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version,
decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder_small/",
......
from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .modeling_parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
from .dac_wrapper import DACConfig, DACModel
\ No newline at end of file
......@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Stable Speech model configuration"""
""" Parler-TTS model configuration"""
from transformers import AutoConfig, logging
from transformers.configuration_utils import PretrainedConfig
......@@ -21,17 +21,17 @@ from transformers.configuration_utils import PretrainedConfig
logger = logging.get_logger(__name__)
MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/stable_speech-small": "https://huggingface.co/facebook/stable_speech-small/resolve/main/config.json",
# See all StableSpeech models at https://huggingface.co/models?filter=stable_speech
"facebook/parler_tts-small": "https://huggingface.co/facebook/parler_tts-small/resolve/main/config.json",
# See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
}
class StableSpeechDecoderConfig(PretrainedConfig):
class ParlerTTSDecoderConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`StableSpeechDecoder`]. It is used to instantiate a
Stable Speech decoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Stable Speech
[facebook/stable_speech-small](https://huggingface.co/facebook/stable_speech-small) architecture.
This is the configuration class to store the configuration of an [`ParlerTTSDecoder`]. It is used to instantiate a
Parler-TTS decoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Parler-TTS
[facebook/parler_tts-small](https://huggingface.co/facebook/parler_tts-small) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
......@@ -39,8 +39,8 @@ class StableSpeechDecoderConfig(PretrainedConfig):
Args:
vocab_size (`int`, *optional*, defaults to 2049):
Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 24):
......@@ -76,7 +76,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
Whether input and output word embeddings should be tied.
"""
model_type = "stable_speech_decoder"
model_type = "parler_tts_decoder"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
......@@ -127,10 +127,10 @@ class StableSpeechDecoderConfig(PretrainedConfig):
)
class StableSpeechConfig(PretrainedConfig):
class ParlerTTSConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`StableSpeechModel`]. It is used to instantiate a
Stable Speech model according to the specified arguments, defining the text encoder, audio encoder and Stable Speech decoder
This is the configuration class to store the configuration of a [`ParlerTTSModel`]. It is used to instantiate a
Parler-TTS model according to the specified arguments, defining the text encoder, audio encoder and Parler-TTS decoder
configs.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
......@@ -153,24 +153,24 @@ class StableSpeechConfig(PretrainedConfig):
```python
>>> from transformers import (
... StableSpeechConfig,
... StableSpeechDecoderConfig,
... ParlerTTSConfig,
... ParlerTTSDecoderConfig,
... T5Config,
... EncodecConfig,
... StableSpeechForConditionalGeneration,
... ParlerTTSForConditionalGeneration,
... )
>>> # Initializing text encoder, audio encoder, and decoder model configurations
>>> text_encoder_config = T5Config()
>>> audio_encoder_config = EncodecConfig()
>>> decoder_config = StableSpeechDecoderConfig()
>>> decoder_config = ParlerTTSDecoderConfig()
>>> configuration = StableSpeechConfig.from_sub_models_config(
>>> configuration = ParlerTTSConfig.from_sub_models_config(
... text_encoder_config, audio_encoder_config, decoder_config
... )
>>> # Initializing a StableSpeechForConditionalGeneration (with random weights) from the facebook/stable_speech-small style configuration
>>> model = StableSpeechForConditionalGeneration(configuration)
>>> # Initializing a ParlerTTSForConditionalGeneration (with random weights) from the facebook/parler_tts-small style configuration
>>> model = ParlerTTSForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
......@@ -179,14 +179,14 @@ class StableSpeechConfig(PretrainedConfig):
>>> config_decoder = model.config.decoder
>>> # Saving the model, including its configuration
>>> model.save_pretrained("stable_speech-model")
>>> model.save_pretrained("parler_tts-model")
>>> # loading model and config from pretrained folder
>>> stable_speech_config = StableSpeechConfig.from_pretrained("stable_speech-model")
>>> model = StableSpeechForConditionalGeneration.from_pretrained("stable_speech-model", config=stable_speech_config)
>>> parler_tts_config = ParlerTTSConfig.from_pretrained("parler_tts-model")
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("parler_tts-model", config=parler_tts_config)
```"""
model_type = "stable_speech"
model_type = "parler_tts"
is_composition = True
def __init__(self, vocab_size=1024, **kwargs):
......@@ -205,7 +205,7 @@ class StableSpeechConfig(PretrainedConfig):
self.vocab_size = vocab_size
self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
self.decoder = StableSpeechDecoderConfig(**decoder_config)
self.decoder = ParlerTTSDecoderConfig(**decoder_config)
self.is_encoder_decoder = True
@classmethod
......@@ -213,15 +213,15 @@ class StableSpeechConfig(PretrainedConfig):
cls,
text_encoder_config: PretrainedConfig,
audio_encoder_config: PretrainedConfig,
decoder_config: StableSpeechDecoderConfig,
decoder_config: ParlerTTSDecoderConfig,
**kwargs,
):
r"""
Instantiate a [`StableSpeechConfig`] (or a derived class) from text encoder, audio encoder and decoder
Instantiate a [`ParlerTTSConfig`] (or a derived class) from text encoder, audio encoder and decoder
configurations.
Returns:
[`StableSpeechConfig`]: An instance of a configuration object
[`ParlerTTSConfig`]: An instance of a configuration object
"""
return cls(
......
......@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch StableSpeech model."""
""" PyTorch ParlerTTS model."""
import copy
import inspect
import math
......@@ -44,7 +44,7 @@ from transformers.utils import (
replace_return_docstrings,
)
from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
if TYPE_CHECKING:
......@@ -52,12 +52,12 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "StableSpeechConfig"
_CHECKPOINT_FOR_DOC = "facebook/stable_speech-small"
_CONFIG_FOR_DOC = "ParlerTTSConfig"
_CHECKPOINT_FOR_DOC = "facebook/parler_tts-small"
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/stable_speech-small",
# See all StableSpeech models at https://huggingface.co/models?filter=stable_speech
"facebook/parler_tts-small",
# See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
]
def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
......@@ -133,7 +133,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
return input_ids, pattern_mask
@dataclass
class StableSpeechUnconditionalInput(ModelOutput):
class ParlerTTSUnconditionalInput(ModelOutput):
"""
Args:
encoder_outputs (`Tuple[torch.FloatTensor]` of length 1, with tensor shape `(batch_size, sequence_length, hidden_size)`):
......@@ -170,8 +170,8 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
return shifted_input_ids
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenSinusoidalPositionalEmbedding with Musicgen->StableSpeech
class StableSpeechSinusoidalPositionalEmbedding(nn.Module):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenSinusoidalPositionalEmbedding with Musicgen->ParlerTTS
class ParlerTTSSinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions: int, embedding_dim: int):
......@@ -217,8 +217,8 @@ class StableSpeechSinusoidalPositionalEmbedding(nn.Module):
return self.weights.index_select(0, position_ids.view(-1)).detach()
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->StableSpeech
class StableSpeechAttention(nn.Module):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->ParlerTTS
class ParlerTTSAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
......@@ -229,7 +229,7 @@ class StableSpeechAttention(nn.Module):
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[StableSpeechConfig] = None,
config: Optional[ParlerTTSConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
......@@ -376,13 +376,13 @@ class StableSpeechAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer with Musicgen->StableSpeech
class StableSpeechDecoderLayer(nn.Module):
def __init__(self, config: StableSpeechDecoderConfig):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer with Musicgen->ParlerTTS
class ParlerTTSDecoderLayer(nn.Module):
def __init__(self, config: ParlerTTSDecoderConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = StableSpeechAttention(
self.self_attn = ParlerTTSAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
......@@ -394,7 +394,7 @@ class StableSpeechDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = StableSpeechAttention(
self.encoder_attn = ParlerTTSAttention(
self.embed_dim,
config.num_attention_heads,
dropout=config.attention_dropout,
......@@ -496,17 +496,17 @@ class StableSpeechDecoderLayer(nn.Module):
return outputs
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->StableSpeech
class StableSpeechPreTrainedModel(PreTrainedModel):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->ParlerTTS
class ParlerTTSPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = StableSpeechDecoderConfig
config_class = ParlerTTSDecoderConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["StableSpeechDecoderLayer", "StableSpeechAttention"]
_no_split_modules = ["ParlerTTSDecoderLayer", "ParlerTTSAttention"]
def _init_weights(self, module):
std = self.config.initializer_factor
......@@ -522,7 +522,7 @@ class StableSpeechPreTrainedModel(PreTrainedModel):
MUSICGEN_START_DOCSTRING = r"""
The StableSpeech model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by
The ParlerTTS model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by
Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi, Alexandre Défossez. It is an
encoder decoder transformer trained on the task of conditional music generation
......@@ -535,7 +535,7 @@ MUSICGEN_START_DOCSTRING = r"""
and behavior.
Parameters:
config ([`StableSpeechConfig`]): Model configuration class with all the parameters of the model.
config ([`ParlerTTSConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
......@@ -716,13 +716,13 @@ MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
"""
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder with Musicgen->StableSpeech
class StableSpeechDecoder(StableSpeechPreTrainedModel):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder with Musicgen->ParlerTTS
class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`StableSpeechDecoderLayer`]
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ParlerTTSDecoderLayer`]
"""
def __init__(self, config: StableSpeechDecoderConfig):
def __init__(self, config: ParlerTTSDecoderConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
......@@ -737,12 +737,12 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
[nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
)
self.embed_positions = StableSpeechSinusoidalPositionalEmbedding(
self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding(
config.max_position_embeddings,
config.hidden_size,
)
self.layers = nn.ModuleList([StableSpeechDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList([ParlerTTSDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.gradient_checkpointing = False
......@@ -930,14 +930,14 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
@add_start_docstrings(
"The bare StableSpeech decoder model outputting raw hidden-states without any specific head on top.",
"The bare ParlerTTS decoder model outputting raw hidden-states without any specific head on top.",
MUSICGEN_START_DOCSTRING,
)
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with Musicgen->StableSpeech
class StableSpeechModel(StableSpeechPreTrainedModel):
def __init__(self, config: StableSpeechDecoderConfig):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with Musicgen->ParlerTTS
class ParlerTTSModel(ParlerTTSPreTrainedModel):
def __init__(self, config: ParlerTTSDecoderConfig):
super().__init__(config)
self.decoder = StableSpeechDecoder(config)
self.decoder = ParlerTTSDecoder(config)
# Initialize weights and apply final processing
self.post_init()
......@@ -1006,15 +1006,15 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
@add_start_docstrings(
"The Stable Speech decoder model with a language modelling head on top.",
"The Parler-TTS decoder model with a language modelling head on top.",
MUSICGEN_START_DOCSTRING,
)
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with Musicgen->StableSpeech
class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
def __init__(self, config: StableSpeechDecoderConfig):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with Musicgen->ParlerTTS
class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
def __init__(self, config: ParlerTTSDecoderConfig):
super().__init__(config)
self.model = StableSpeechModel(config)
self.model = ParlerTTSModel(config)
self.num_codebooks = config.num_codebooks
self.lm_heads = nn.ModuleList(
......@@ -1379,7 +1379,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
)
# 6. Prepare `input_ids` which will be used for auto-regressive generation
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS)
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config.bos_token_id,
......@@ -1498,29 +1498,29 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
@add_start_docstrings(
"The composite Stable Speech model with a text encoder, audio encoder and StableSpeech decoder, "
"The composite Parler-TTS model with a text encoder, audio encoder and ParlerTTS decoder, "
"for music generation tasks with one or both of text and audio prompts.",
MUSICGEN_START_DOCSTRING,
)
class StableSpeechForConditionalGeneration(PreTrainedModel):
config_class = StableSpeechConfig
class ParlerTTSForConditionalGeneration(PreTrainedModel):
config_class = ParlerTTSConfig
base_model_prefix = "encoder_decoder"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
def __init__(
self,
config: Optional[StableSpeechConfig] = None,
config: Optional[ParlerTTSConfig] = None,
text_encoder: Optional[PreTrainedModel] = None,
audio_encoder: Optional[PreTrainedModel] = None,
decoder: Optional[StableSpeechForCausalLM] = None,
decoder: Optional[ParlerTTSForCausalLM] = None,
):
if config is None and (text_encoder is None or audio_encoder is None or decoder is None):
raise ValueError(
"Either a configuration has to be provided, or all three of text encoder, audio encoder and Stable Speech decoder."
"Either a configuration has to be provided, or all three of text encoder, audio encoder and Parler-TTS decoder."
)
if config is None:
config = StableSpeechConfig.from_sub_models_config(
config = ParlerTTSConfig.from_sub_models_config(
text_encoder.config, audio_encoder.config, decoder.config
)
else:
......@@ -1530,7 +1530,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the Stable Speech decoder's configuration, it has to be equal"
"If `cross_attention_hidden_size` is specified in the Parler-TTS decoder's configuration, it has to be equal"
f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for"
" `config.text_encoder.hidden_size`."
......@@ -1550,7 +1550,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
audio_encoder = AutoModel.from_config(config.audio_encoder)
if decoder is None:
decoder = StableSpeechForCausalLM(config.decoder)
decoder = ParlerTTSForCausalLM(config.decoder)
self.text_encoder = text_encoder
self.audio_encoder = audio_encoder
......@@ -1652,15 +1652,15 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
Example:
```python
>>> from transformers import StableSpeechForConditionalGeneration
>>> from transformers import ParlerTTSForConditionalGeneration
>>> model = StableSpeechForConditionalGeneration.from_pretrained("facebook/stable_speech-small")
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
```"""
# At the moment fast initialization is not supported for composite models
if kwargs.get("_fast_init", False):
logger.warning(
"Fast initialization is currently not supported for StableSpeechForConditionalGeneration. "
"Fast initialization is currently not supported for ParlerTTSForConditionalGeneration. "
"Falling back to slow initialization..."
)
kwargs["_fast_init"] = False
......@@ -1677,7 +1677,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
**kwargs,
) -> PreTrainedModel:
r"""
Instantiate a text encoder, an audio encoder, and a Stable Speech decoder from one, two or three base classes of the
Instantiate a text encoder, an audio encoder, and a Parler-TTS decoder from one, two or three base classes of the
library from pretrained model checkpoints.
......@@ -1708,7 +1708,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `gpt2`, or namespaced under a user or
organization name, like `facebook/stable_speech-small`.
organization name, like `facebook/parler_tts-small`.
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
......@@ -1731,18 +1731,18 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
Example:
```python
>>> from transformers import StableSpeechForConditionalGeneration
>>> from transformers import ParlerTTSForConditionalGeneration
>>> # initialize a stable_speech model from a t5 text encoder, encodec audio encoder, and stable_speech decoder
>>> model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
>>> # initialize a parler_tts model from a t5 text encoder, encodec audio encoder, and parler_tts decoder
>>> model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
... text_encoder_pretrained_model_name_or_path="t5-base",
... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
... decoder_pretrained_model_name_or_path="facebook/stable_speech-small",
... decoder_pretrained_model_name_or_path="facebook/parler_tts-small",
... )
>>> # saving model after fine-tuning
>>> model.save_pretrained("./stable_speech-ft")
>>> model.save_pretrained("./parler_tts-ft")
>>> # load fine-tuned model
>>> model = StableSpeechForConditionalGeneration.from_pretrained("./stable_speech-ft")
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("./parler_tts-ft")
```"""
kwargs_text_encoder = {
......@@ -1836,11 +1836,11 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if "config" not in kwargs_decoder:
# TODO: reput AutoConfig once added to transformers
decoder_config, kwargs_decoder = StableSpeechDecoderConfig.from_pretrained(
decoder_config, kwargs_decoder = ParlerTTSDecoderConfig.from_pretrained(
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
)
if isinstance(decoder_config, StableSpeechConfig):
if isinstance(decoder_config, ParlerTTSConfig):
decoder_config = decoder_config.decoder
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
......@@ -1863,10 +1863,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
"`decoder_config` to `.from_sub_models_pretrained(...)`"
)
decoder = StableSpeechForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
decoder = ParlerTTSForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
# instantiate config with corresponding kwargs
config = StableSpeechConfig.from_sub_models_config(
config = ParlerTTSConfig.from_sub_models_config(
text_encoder.config, audio_encoder.config, decoder.config, **kwargs
)
return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config)
......@@ -1900,11 +1900,11 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
Examples:
```python
>>> from transformers import AutoProcessor, StableSpeechForConditionalGeneration
>>> from transformers import AutoProcessor, ParlerTTSForConditionalGeneration
>>> import torch
>>> processor = AutoProcessor.from_pretrained("facebook/stable_speech-small")
>>> model = StableSpeechForConditionalGeneration.from_pretrained("facebook/stable_speech-small")
>>> processor = AutoProcessor.from_pretrained("facebook/parler_tts-small")
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
>>> inputs = processor(
... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
......@@ -2479,7 +2479,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
" increasing `max_new_tokens`."
)
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config.bos_token_id,
......@@ -2649,9 +2649,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
Example:
```python
>>> from transformers import StableSpeechForConditionalGeneration
>>> from transformers import ParlerTTSForConditionalGeneration
>>> model = StableSpeechForConditionalGeneration.from_pretrained("facebook/stable_speech-small")
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
>>> # get the unconditional (or 'null') inputs for the model
>>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
......@@ -2663,7 +2663,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)
return StableSpeechUnconditionalInput(
return ParlerTTSUnconditionalInput(
encoder_outputs=(last_hidden_state,),
attention_mask=attention_mask,
guidance_scale=1.0,
......
......@@ -4,7 +4,7 @@ import dac
model_path = dac.utils.download(model_type="44khz")
model = dac.DAC.load(model_path)
from stable_speech import DACConfig, DACModel
from parler_tts import DACConfig, DACModel
hf_dac = DACModel(DACConfig())
hf_dac.model.load_state_dict(model.state_dict())
......
......@@ -63,7 +63,7 @@ from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from transformers.integrations import is_wandb_available
from transformers import AutoConfig, AutoModel
from stable_speech import DACConfig, DACModel
from parler_tts import DACConfig, DACModel
from transformers.modeling_outputs import BaseModelOutput
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
......@@ -73,7 +73,7 @@ from accelerate import Accelerator
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
from accelerate.utils.memory import release_memory
from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSConfig, apply_delay_pattern_mask, build_delay_pattern_mask
if is_wandb_available():
from wandb import Audio
......@@ -455,7 +455,7 @@ class DataTrainingArguments:
}
)
wandb_project: str = field(
default="stable-speech",
default="parler-speech",
metadata={"help": "The name of the wandb project."},
)
save_to_disk: str = field(
......@@ -480,7 +480,7 @@ class DataTrainingArguments:
)
@dataclass
class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
dtype: Optional[str] = field(
default="float32",
metadata={
......@@ -525,7 +525,7 @@ class DataCollatorEncodecWithPadding:
@dataclass
class DataCollatorStableSpeechWithPadding:
class DataCollatorParlerTTSWithPadding:
"""
Data collator that will dynamically pad the inputs received.
Args:
......@@ -716,14 +716,14 @@ def load_multiple_datasets(
# TODO(YL): I forgot to create unique ids for MLS english.
# To iterate faster, I bypass the original id check and do another one. - Done once because assuming it won't change next time
# if dataset_dict["name"] == "stable-speech/mls_eng_10k":
# if dataset_dict["name"] == "parler-tts/mls_eng_10k":
# def concat_ids(book_id, speaker_id, begin_time):
# return {"id": f"{book_id}_{speaker_id}_{str(begin_time).replace('.', '_')}"}
# dataset = dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
# metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
# metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
if dataset_dict["name"] != "stable-speech/mls_eng_10k":
if dataset_dict["name"] != "parler-tts/mls_eng_10k":
if id_column_name is not None and id_column_name not in dataset.column_names:
raise ValueError(
f"id_column_name={id_column_name} but has not been found in the dataset columns"
......@@ -751,7 +751,7 @@ def load_multiple_datasets(
dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
if id_column_name is not None and dataset_dict["name"] != "stable-speech/mls_eng_10k":
if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k":
if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
......@@ -785,7 +785,7 @@ def main():
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, StableSpeechTrainingArguments))
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
......@@ -795,7 +795,7 @@ def main():
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_stable_speech", model_args, data_args)
send_example_telemetry("run_parler_tts", model_args, data_args)
if training_args.dtype == "float16":
mixed_precision = "fp16"
......@@ -996,7 +996,7 @@ def main():
# 3. Next, let's load the config.
# TODO(YL): add the option to create the config from scratch
config = StableSpeechConfig.from_pretrained(
config = ParlerTTSConfig.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
token=data_args.token,
......@@ -1011,7 +1011,7 @@ def main():
})
# create model + TODO(YL): not from_pretrained probably
model = StableSpeechForConditionalGeneration.from_pretrained(
model = ParlerTTSForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
config=config,
......@@ -1353,7 +1353,7 @@ def main():
)
# Instantiate custom data collator
data_collator = DataCollatorStableSpeechWithPadding(
data_collator = DataCollatorParlerTTSWithPadding(
audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of,
padding=padding, prompt_max_length=data_args.max_prompt_token_length, description_max_length=data_args.max_description_token_length, audio_max_length = audio_max_length
)
......
......@@ -41,7 +41,7 @@ with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
long_description = f.read()
# read version
with open(os.path.join(here, "stable_speech", "__init__.py"), encoding="utf-8") as f:
with open(os.path.join(here, "parler_tts", "__init__.py"), encoding="utf-8") as f:
for line in f:
if line.startswith("__version__"):
version = line.split("=")[1].strip().strip('"')
......@@ -50,7 +50,7 @@ with open(os.path.join(here, "stable_speech", "__init__.py"), encoding="utf-8")
raise RuntimeError("Unable to find version string.")
setuptools.setup(
name="stable_speech",
name="parler_tts",
version=version,
description="Toolkit for reproducing Stability AI's text-to-speech model.",
long_description=long_description,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment