"src/targets/vscode:/vscode.git/clone" did not exist on "f279cda6ade4a1cca67d503ff4a3576c23ee7dc6"
Commit 91542bfa authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

remove stable speech mentions

parent 315d3fca
-# Stable Speech
+# Parler-TTS
 Work in-progress reproduction of the text-to-speech (TTS) model from the paper [Natural language guidance of high-fidelity text-to-speech with synthetic annotations](https://www.text-description-to-speech.com)
 by Dan Lyth and Simon King, from Stability AI and Edinburgh University respectively.
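For orientation before the diff: the renamed package keeps the MusicGen-style composite API, with a text encoder, an audio encoder, and a Parler-TTS decoder assembled via `from_sub_models_pretrained`, exactly as the conversion scripts below do. A minimal sketch using the renamed entry points — the local decoder path is an illustrative placeholder, and the config values are pared down from the scripts:

```python
# Hedged quick-start with the renamed classes; mirrors the conversion
# scripts in this commit. The local decoder path is a placeholder.
from parler_tts import (
    ParlerTTSDecoderConfig,
    ParlerTTSForCausalLM,
    ParlerTTSForConditionalGeneration,
)

# Randomly initialized decoder from a config (sizes are illustrative).
decoder_config = ParlerTTSDecoderConfig(vocab_size=2049, num_hidden_layers=4)
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained("./artefacts/decoder/")

# Compose it with pretrained text and audio encoders.
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
    text_encoder_pretrained_model_name_or_path="t5-base",
    audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
    decoder_pretrained_model_name_or_path="./artefacts/decoder/",
)
```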
...
-from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
+from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
 from transformers import T5Config, EncodecConfig
 from transformers import AutoConfig
@@ -12,7 +12,7 @@ encodec = AutoConfig.from_pretrained(encodec_version)
 encodec_vocab_size = encodec.codebook_size
-decoder_config = StableSpeechDecoderConfig(
+decoder_config = ParlerTTSDecoderConfig(
     vocab_size=encodec_vocab_size+1,
     max_position_embeddings=2048,
     num_hidden_layers=4,
@@ -34,13 +34,13 @@ decoder_config = StableSpeechDecoderConfig(
-decoder = StableSpeechForCausalLM(decoder_config)
+decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
-model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
+model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
...
-from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
+from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
 from transformers import AutoConfig
 from transformers import AutoModel
 from transformers import AutoConfig, AutoModel
-from stable_speech import DACConfig, DACModel
+from parler_tts import DACConfig, DACModel
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
@@ -19,7 +19,7 @@ encodec = AutoConfig.from_pretrained(encodec_version)
 encodec_vocab_size = encodec.codebook_size
-decoder_config = StableSpeechDecoderConfig(
+decoder_config = ParlerTTSDecoderConfig(
     vocab_size=encodec_vocab_size+1,
     max_position_embeddings=2048,
     num_hidden_layers=4,
@@ -41,13 +41,13 @@ decoder_config = StableSpeechDecoderConfig(
-decoder = StableSpeechForCausalLM(decoder_config)
+decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
-model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
+model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
...
-from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
+from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
 from transformers import T5Config, EncodecConfig
 from transformers import AutoConfig
 from transformers import AutoConfig, AutoModel
-from stable_speech import DACConfig, DACModel
+from parler_tts import DACConfig, DACModel
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
@@ -20,7 +20,7 @@ encodec = AutoConfig.from_pretrained(encodec_version)
 encodec_vocab_size = encodec.codebook_size
-decoder_config = StableSpeechDecoderConfig(
+decoder_config = ParlerTTSDecoderConfig(
     vocab_size=encodec_vocab_size+1,
     max_position_embeddings=3000, # 30 s = 2580
     num_hidden_layers=12,
@@ -40,13 +40,13 @@ decoder_config = StableSpeechDecoderConfig(
 )
-decoder = StableSpeechForCausalLM(decoder_config)
+decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
-model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
+model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
...
-from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
+from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
 from transformers import T5Config, EncodecConfig
 from transformers import AutoConfig
 from transformers import AutoConfig, AutoModel
-from stable_speech import DACConfig, DACModel
+from parler_tts import DACConfig, DACModel
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
@@ -20,7 +20,7 @@ encodec = AutoConfig.from_pretrained(encodec_version)
 encodec_vocab_size = encodec.codebook_size
-decoder_config = StableSpeechDecoderConfig(
+decoder_config = ParlerTTSDecoderConfig(
     vocab_size=encodec_vocab_size+1,
     max_position_embeddings=4096, # 30 s = 2580
     num_hidden_layers=8,
@@ -40,13 +40,13 @@ decoder_config = StableSpeechDecoderConfig(
 )
-decoder = StableSpeechForCausalLM(decoder_config)
+decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder_small/")
-model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
+model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder_small/",
...
-from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
-from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
+from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
+from .modeling_parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
 from .dac_wrapper import DACConfig, DACModel
\ No newline at end of file
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Stable Speech model configuration"""
+""" Parler-TTS model configuration"""
 from transformers import AutoConfig, logging
 from transformers.configuration_utils import PretrainedConfig
@@ -21,17 +21,17 @@ from transformers.configuration_utils import PretrainedConfig
 logger = logging.get_logger(__name__)
 MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "facebook/stable_speech-small": "https://huggingface.co/facebook/stable_speech-small/resolve/main/config.json",
-    # See all StableSpeech models at https://huggingface.co/models?filter=stable_speech
+    "facebook/parler_tts-small": "https://huggingface.co/facebook/parler_tts-small/resolve/main/config.json",
+    # See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
 }
-class StableSpeechDecoderConfig(PretrainedConfig):
+class ParlerTTSDecoderConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of an [`StableSpeechDecoder`]. It is used to instantiate a
-    Stable Speech decoder according to the specified arguments, defining the model architecture. Instantiating a
-    configuration with the defaults will yield a similar configuration to that of the Stable Speech
-    [facebook/stable_speech-small](https://huggingface.co/facebook/stable_speech-small) architecture.
+    This is the configuration class to store the configuration of an [`ParlerTTSDecoder`]. It is used to instantiate a
+    Parler-TTS decoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the Parler-TTS
+    [facebook/parler_tts-small](https://huggingface.co/facebook/parler_tts-small) architecture.
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -39,8 +39,8 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 2049):
-            Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
-            represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
+            Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
+            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
             Dimensionality of the layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 24):
@@ -76,7 +76,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
             Whether input and output word embeddings should be tied.
     """
-    model_type = "stable_speech_decoder"
+    model_type = "parler_tts_decoder"
     keys_to_ignore_at_inference = ["past_key_values"]
     def __init__(
@@ -127,10 +127,10 @@ class StableSpeechDecoderConfig(PretrainedConfig):
         )
-class StableSpeechConfig(PretrainedConfig):
+class ParlerTTSConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`StableSpeechModel`]. It is used to instantiate a
-    Stable Speech model according to the specified arguments, defining the text encoder, audio encoder and Stable Speech decoder
+    This is the configuration class to store the configuration of a [`ParlerTTSModel`]. It is used to instantiate a
+    Parler-TTS model according to the specified arguments, defining the text encoder, audio encoder and Parler-TTS decoder
     configs.
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
@@ -153,24 +153,24 @@ class StableSpeechConfig(PretrainedConfig):
     ```python
     >>> from transformers import (
-    ...     StableSpeechConfig,
-    ...     StableSpeechDecoderConfig,
+    ...     ParlerTTSConfig,
+    ...     ParlerTTSDecoderConfig,
     ...     T5Config,
     ...     EncodecConfig,
-    ...     StableSpeechForConditionalGeneration,
+    ...     ParlerTTSForConditionalGeneration,
     ... )
     >>> # Initializing text encoder, audio encoder, and decoder model configurations
     >>> text_encoder_config = T5Config()
     >>> audio_encoder_config = EncodecConfig()
-    >>> decoder_config = StableSpeechDecoderConfig()
+    >>> decoder_config = ParlerTTSDecoderConfig()
-    >>> configuration = StableSpeechConfig.from_sub_models_config(
+    >>> configuration = ParlerTTSConfig.from_sub_models_config(
     ...     text_encoder_config, audio_encoder_config, decoder_config
     ... )
-    >>> # Initializing a StableSpeechForConditionalGeneration (with random weights) from the facebook/stable_speech-small style configuration
-    >>> model = StableSpeechForConditionalGeneration(configuration)
+    >>> # Initializing a ParlerTTSForConditionalGeneration (with random weights) from the facebook/parler_tts-small style configuration
+    >>> model = ParlerTTSForConditionalGeneration(configuration)
     >>> # Accessing the model configuration
     >>> configuration = model.config
@@ -179,14 +179,14 @@ class StableSpeechConfig(PretrainedConfig):
     >>> config_decoder = model.config.decoder
     >>> # Saving the model, including its configuration
-    >>> model.save_pretrained("stable_speech-model")
+    >>> model.save_pretrained("parler_tts-model")
     >>> # loading model and config from pretrained folder
-    >>> stable_speech_config = StableSpeechConfig.from_pretrained("stable_speech-model")
-    >>> model = StableSpeechForConditionalGeneration.from_pretrained("stable_speech-model", config=stable_speech_config)
+    >>> parler_tts_config = ParlerTTSConfig.from_pretrained("parler_tts-model")
+    >>> model = ParlerTTSForConditionalGeneration.from_pretrained("parler_tts-model", config=parler_tts_config)
     ```"""
-    model_type = "stable_speech"
+    model_type = "parler_tts"
     is_composition = True
     def __init__(self, vocab_size=1024, **kwargs):
@@ -205,7 +205,7 @@ class StableSpeechConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
         self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
-        self.decoder = StableSpeechDecoderConfig(**decoder_config)
+        self.decoder = ParlerTTSDecoderConfig(**decoder_config)
         self.is_encoder_decoder = True
     @classmethod
@@ -213,15 +213,15 @@ class StableSpeechConfig(PretrainedConfig):
         cls,
         text_encoder_config: PretrainedConfig,
         audio_encoder_config: PretrainedConfig,
-        decoder_config: StableSpeechDecoderConfig,
+        decoder_config: ParlerTTSDecoderConfig,
         **kwargs,
     ):
         r"""
-        Instantiate a [`StableSpeechConfig`] (or a derived class) from text encoder, audio encoder and decoder
+        Instantiate a [`ParlerTTSConfig`] (or a derived class) from text encoder, audio encoder and decoder
         configurations.
         Returns:
-            [`StableSpeechConfig`]: An instance of a configuration object
+            [`ParlerTTSConfig`]: An instance of a configuration object
         """
         return cls(
...
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch StableSpeech model."""
+""" PyTorch ParlerTTS model."""
 import copy
 import inspect
 import math
@@ -44,7 +44,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
-from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
+from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
 if TYPE_CHECKING:
@@ -52,12 +52,12 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
-_CONFIG_FOR_DOC = "StableSpeechConfig"
-_CHECKPOINT_FOR_DOC = "facebook/stable_speech-small"
+_CONFIG_FOR_DOC = "ParlerTTSConfig"
+_CHECKPOINT_FOR_DOC = "facebook/parler_tts-small"
 MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "facebook/stable_speech-small",
-    # See all StableSpeech models at https://huggingface.co/models?filter=stable_speech
+    "facebook/parler_tts-small",
+    # See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
 ]
 def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
@@ -133,7 +133,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
     return input_ids, pattern_mask
 @dataclass
-class StableSpeechUnconditionalInput(ModelOutput):
+class ParlerTTSUnconditionalInput(ModelOutput):
     """
     Args:
         encoder_outputs (`Tuple[torch.FloatTensor]` of length 1, with tensor shape `(batch_size, sequence_length, hidden_size)`):
@@ -170,8 +170,8 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
     return shifted_input_ids
-# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenSinusoidalPositionalEmbedding with Musicgen->StableSpeech
-class StableSpeechSinusoidalPositionalEmbedding(nn.Module):
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenSinusoidalPositionalEmbedding with Musicgen->ParlerTTS
+class ParlerTTSSinusoidalPositionalEmbedding(nn.Module):
     """This module produces sinusoidal positional embeddings of any length."""
     def __init__(self, num_positions: int, embedding_dim: int):
@@ -217,8 +217,8 @@ class StableSpeechSinusoidalPositionalEmbedding(nn.Module):
         return self.weights.index_select(0, position_ids.view(-1)).detach()
-# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->StableSpeech
-class StableSpeechAttention(nn.Module):
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->ParlerTTS
+class ParlerTTSAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
     def __init__(
@@ -229,7 +229,7 @@ class StableSpeechAttention(nn.Module):
         is_decoder: bool = False,
         bias: bool = True,
         is_causal: bool = False,
-        config: Optional[StableSpeechConfig] = None,
+        config: Optional[ParlerTTSConfig] = None,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -376,13 +376,13 @@ class StableSpeechAttention(nn.Module):
         return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer with Musicgen->StableSpeech
-class StableSpeechDecoderLayer(nn.Module):
-    def __init__(self, config: StableSpeechDecoderConfig):
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer with Musicgen->ParlerTTS
+class ParlerTTSDecoderLayer(nn.Module):
+    def __init__(self, config: ParlerTTSDecoderConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = StableSpeechAttention(
+        self.self_attn = ParlerTTSAttention(
             embed_dim=self.embed_dim,
             num_heads=config.num_attention_heads,
             dropout=config.attention_dropout,
@@ -394,7 +394,7 @@ class StableSpeechDecoderLayer(nn.Module):
         self.activation_dropout = config.activation_dropout
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.encoder_attn = StableSpeechAttention(
+        self.encoder_attn = ParlerTTSAttention(
             self.embed_dim,
             config.num_attention_heads,
             dropout=config.attention_dropout,
@@ -496,17 +496,17 @@ class StableSpeechDecoderLayer(nn.Module):
         return outputs
-# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->StableSpeech
-class StableSpeechPreTrainedModel(PreTrainedModel):
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->ParlerTTS
+class ParlerTTSPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
     models.
     """
-    config_class = StableSpeechDecoderConfig
+    config_class = ParlerTTSDecoderConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["StableSpeechDecoderLayer", "StableSpeechAttention"]
+    _no_split_modules = ["ParlerTTSDecoderLayer", "ParlerTTSAttention"]
     def _init_weights(self, module):
         std = self.config.initializer_factor
@@ -522,7 +522,7 @@ class StableSpeechPreTrainedModel(PreTrainedModel):
 MUSICGEN_START_DOCSTRING = r"""
-    The StableSpeech model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by
+    The ParlerTTS model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by
     Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi, Alexandre Défossez. It is an
     encoder decoder transformer trained on the task of conditional music generation
@@ -535,7 +535,7 @@ MUSICGEN_START_DOCSTRING = r"""
     and behavior.
     Parameters:
-        config ([`StableSpeechConfig`]): Model configuration class with all the parameters of the model.
+        config ([`ParlerTTSConfig`]): Model configuration class with all the parameters of the model.
             Initializing with a config file does not load the weights associated with the model, only the
             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 """
@@ -716,13 +716,13 @@ MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
 """
-# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder with Musicgen->StableSpeech
-class StableSpeechDecoder(StableSpeechPreTrainedModel):
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder with Musicgen->ParlerTTS
+class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
     """
-    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`StableSpeechDecoderLayer`]
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ParlerTTSDecoderLayer`]
     """
-    def __init__(self, config: StableSpeechDecoderConfig):
+    def __init__(self, config: ParlerTTSDecoderConfig):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.layerdrop
@@ -737,12 +737,12 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
             [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
         )
-        self.embed_positions = StableSpeechSinusoidalPositionalEmbedding(
+        self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding(
             config.max_position_embeddings,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([StableSpeechDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList([ParlerTTSDecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.layer_norm = nn.LayerNorm(config.hidden_size)
         self.gradient_checkpointing = False
@@ -930,14 +930,14 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
 @add_start_docstrings(
-    "The bare StableSpeech decoder model outputting raw hidden-states without any specific head on top.",
+    "The bare ParlerTTS decoder model outputting raw hidden-states without any specific head on top.",
     MUSICGEN_START_DOCSTRING,
 )
-# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with Musicgen->StableSpeech
-class StableSpeechModel(StableSpeechPreTrainedModel):
-    def __init__(self, config: StableSpeechDecoderConfig):
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with Musicgen->ParlerTTS
+class ParlerTTSModel(ParlerTTSPreTrainedModel):
+    def __init__(self, config: ParlerTTSDecoderConfig):
         super().__init__(config)
-        self.decoder = StableSpeechDecoder(config)
+        self.decoder = ParlerTTSDecoder(config)
         # Initialize weights and apply final processing
         self.post_init()
@@ -1006,15 +1006,15 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
 @add_start_docstrings(
-    "The Stable Speech decoder model with a language modelling head on top.",
+    "The Parler-TTS decoder model with a language modelling head on top.",
     MUSICGEN_START_DOCSTRING,
 )
-# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with Musicgen->StableSpeech
-class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
-    def __init__(self, config: StableSpeechDecoderConfig):
+# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with Musicgen->ParlerTTS
+class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
+    def __init__(self, config: ParlerTTSDecoderConfig):
         super().__init__(config)
-        self.model = StableSpeechModel(config)
+        self.model = ParlerTTSModel(config)
         self.num_codebooks = config.num_codebooks
         self.lm_heads = nn.ModuleList(
@@ -1379,7 +1379,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         )
         # 6. Prepare `input_ids` which will be used for auto-regressive generation
-        # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
+        # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS)
         input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
             input_ids,
             bos_token_id=generation_config.bos_token_id,
@@ -1498,29 +1498,29 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
 @add_start_docstrings(
-    "The composite Stable Speech model with a text encoder, audio encoder and StableSpeech decoder, "
+    "The composite Parler-TTS model with a text encoder, audio encoder and ParlerTTS decoder, "
     "for music generation tasks with one or both of text and audio prompts.",
     MUSICGEN_START_DOCSTRING,
 )
-class StableSpeechForConditionalGeneration(PreTrainedModel):
-    config_class = StableSpeechConfig
+class ParlerTTSForConditionalGeneration(PreTrainedModel):
+    config_class = ParlerTTSConfig
     base_model_prefix = "encoder_decoder"
     main_input_name = "input_ids"
     supports_gradient_checkpointing = True
     def __init__(
         self,
-        config: Optional[StableSpeechConfig] = None,
+        config: Optional[ParlerTTSConfig] = None,
         text_encoder: Optional[PreTrainedModel] = None,
         audio_encoder: Optional[PreTrainedModel] = None,
-        decoder: Optional[StableSpeechForCausalLM] = None,
+        decoder: Optional[ParlerTTSForCausalLM] = None,
     ):
         if config is None and (text_encoder is None or audio_encoder is None or decoder is None):
             raise ValueError(
-                "Either a configuration has to be provided, or all three of text encoder, audio encoder and Stable Speech decoder."
+                "Either a configuration has to be provided, or all three of text encoder, audio encoder and Parler-TTS decoder."
             )
         if config is None:
-            config = StableSpeechConfig.from_sub_models_config(
+            config = ParlerTTSConfig.from_sub_models_config(
                 text_encoder.config, audio_encoder.config, decoder.config
             )
         else:
@@ -1530,7 +1530,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if config.decoder.cross_attention_hidden_size is not None:
             if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size:
                 raise ValueError(
-                    "If `cross_attention_hidden_size` is specified in the Stable Speech decoder's configuration, it has to be equal"
+                    "If `cross_attention_hidden_size` is specified in the Parler-TTS decoder's configuration, it has to be equal"
                     f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
                     f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for"
                     " `config.text_encoder.hidden_size`."
@@ -1550,7 +1550,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             audio_encoder = AutoModel.from_config(config.audio_encoder)
         if decoder is None:
-            decoder = StableSpeechForCausalLM(config.decoder)
+            decoder = ParlerTTSForCausalLM(config.decoder)
         self.text_encoder = text_encoder
         self.audio_encoder = audio_encoder
@@ -1652,15 +1652,15 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         Example:
         ```python
-        >>> from transformers import StableSpeechForConditionalGeneration
-        >>> model = StableSpeechForConditionalGeneration.from_pretrained("facebook/stable_speech-small")
+        >>> from transformers import ParlerTTSForConditionalGeneration
+        >>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
         ```"""
         # At the moment fast initialization is not supported for composite models
         if kwargs.get("_fast_init", False):
             logger.warning(
-                "Fast initialization is currently not supported for StableSpeechForConditionalGeneration. "
+                "Fast initialization is currently not supported for ParlerTTSForConditionalGeneration. "
                 "Falling back to slow initialization..."
            )
            kwargs["_fast_init"] = False
@@ -1677,7 +1677,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         **kwargs,
     ) -> PreTrainedModel:
         r"""
-        Instantiate a text encoder, an audio encoder, and a Stable Speech decoder from one, two or three base classes of the
+        Instantiate a text encoder, an audio encoder, and a Parler-TTS decoder from one, two or three base classes of the
         library from pretrained model checkpoints.
@@ -1708,7 +1708,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
                 - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                   Valid model ids can be located at the root-level, like `gpt2`, or namespaced under a user or
-                  organization name, like `facebook/stable_speech-small`.
+                  organization name, like `facebook/parler_tts-small`.
                 - A path to a *directory* containing model weights saved using
                   [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
@@ -1731,18 +1731,18 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         Example:
         ```python
-        >>> from transformers import StableSpeechForConditionalGeneration
-        >>> # initialize a stable_speech model from a t5 text encoder, encodec audio encoder, and stable_speech decoder
-        >>> model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
+        >>> from transformers import ParlerTTSForConditionalGeneration
+        >>> # initialize a parler_tts model from a t5 text encoder, encodec audio encoder, and parler_tts decoder
+        >>> model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
         ...     text_encoder_pretrained_model_name_or_path="t5-base",
         ...     audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
-        ...     decoder_pretrained_model_name_or_path="facebook/stable_speech-small",
+        ...     decoder_pretrained_model_name_or_path="facebook/parler_tts-small",
         ... )
         >>> # saving model after fine-tuning
-        >>> model.save_pretrained("./stable_speech-ft")
+        >>> model.save_pretrained("./parler_tts-ft")
         >>> # load fine-tuned model
-        >>> model = StableSpeechForConditionalGeneration.from_pretrained("./stable_speech-ft")
+        >>> model = ParlerTTSForConditionalGeneration.from_pretrained("./parler_tts-ft")
         ```"""
         kwargs_text_encoder = {
@@ -1836,11 +1836,11 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if "config" not in kwargs_decoder:
             # TODO: reput AutoConfig once added to transformers
-            decoder_config, kwargs_decoder = StableSpeechDecoderConfig.from_pretrained(
+            decoder_config, kwargs_decoder = ParlerTTSDecoderConfig.from_pretrained(
                 decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
             )
-            if isinstance(decoder_config, StableSpeechConfig):
+            if isinstance(decoder_config, ParlerTTSConfig):
                 decoder_config = decoder_config.decoder
         if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
@@ -1863,10 +1863,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
                 "`decoder_config` to `.from_sub_models_pretrained(...)`"
             )
-        decoder = StableSpeechForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+        decoder = ParlerTTSForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
         # instantiate config with corresponding kwargs
-        config = StableSpeechConfig.from_sub_models_config(
+        config = ParlerTTSConfig.from_sub_models_config(
             text_encoder.config, audio_encoder.config, decoder.config, **kwargs
         )
         return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config)
@@ -1900,11 +1900,11 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         Examples:
         ```python
-        >>> from transformers import AutoProcessor, StableSpeechForConditionalGeneration
+        >>> from transformers import AutoProcessor, ParlerTTSForConditionalGeneration
         >>> import torch
-        >>> processor = AutoProcessor.from_pretrained("facebook/stable_speech-small")
-        >>> model = StableSpeechForConditionalGeneration.from_pretrained("facebook/stable_speech-small")
+        >>> processor = AutoProcessor.from_pretrained("facebook/parler_tts-small")
+        >>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
         >>> inputs = processor(
         ...     text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
@@ -2479,7 +2479,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
                 " increasing `max_new_tokens`."
             )
-        # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
+        # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS)
         input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
             input_ids,
             bos_token_id=generation_config.bos_token_id,
@@ -2649,9 +2649,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         Example:
         ```python
-        >>> from transformers import StableSpeechForConditionalGeneration
-        >>> model = StableSpeechForConditionalGeneration.from_pretrained("facebook/stable_speech-small")
+        >>> from transformers import ParlerTTSForConditionalGeneration
+        >>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
         >>> # get the unconditional (or 'null') inputs for the model
         >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
@@ -2663,7 +2663,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)
-        return StableSpeechUnconditionalInput(
+        return ParlerTTSUnconditionalInput(
             encoder_outputs=(last_hidden_state,),
             attention_mask=attention_mask,
             guidance_scale=1.0,
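A note on the delay pattern mask that `build_delay_pattern_mask` and `apply_delay_pattern_mask` manage in the modeling code above: each codebook k is offset by k steps relative to codebook 0, MusicGen-style, so one autoregressive step predicts every codebook at a staggered frame position. The toy sketch below only illustrates that staircase layout; the shapes, the BOS id of 2048, and the use of -1 for "still to be predicted" are assumptions of the sketch, not the library's internals.

```python
import torch

# Toy staircase for num_codebooks=4 over 8 decoding steps.
num_codebooks, steps, bos_token_id = 4, 8, 2048
pattern = torch.full((num_codebooks, steps), -1, dtype=torch.long)
for k in range(num_codebooks):
    pattern[k, :k] = bos_token_id  # codebook k is delayed by k steps
print(pattern)
# tensor([[  -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1],
#         [2048,   -1,   -1,   -1,   -1,   -1,   -1,   -1],
#         [2048, 2048,   -1,   -1,   -1,   -1,   -1,   -1],
#         [2048, 2048, 2048,   -1,   -1,   -1,   -1,   -1]])
```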
...
@@ -4,7 +4,7 @@ import dac
 model_path = dac.utils.download(model_type="44khz")
 model = dac.DAC.load(model_path)
-from stable_speech import DACConfig, DACModel
+from parler_tts import DACConfig, DACModel
 hf_dac = DACModel(DACConfig())
 hf_dac.model.load_state_dict(model.state_dict())
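The wrapper populated above is the same `DACConfig`/`DACModel` pair that the other scripts register with the Auto classes via `AutoConfig.register("dac", DACConfig)` and `AutoModel.register(DACConfig, DACModel)`. A sketch of what that registration enables, assuming the wrapped model was saved locally with `save_pretrained` (the `./dac_44khz` path is illustrative):

```python
from transformers import AutoConfig, AutoModel
from parler_tts import DACConfig, DACModel

# One-time mapping from model_type "dac" to the wrapper classes.
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)

# After hf_dac.save_pretrained("./dac_44khz"), the generic Auto API
# resolves the custom architecture like any built-in one.
config = AutoConfig.from_pretrained("./dac_44khz")        # -> DACConfig
audio_encoder = AutoModel.from_pretrained("./dac_44khz")  # -> DACModel
```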
...
@@ -63,7 +63,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version
 from transformers.integrations import is_wandb_available
 from transformers import AutoConfig, AutoModel
-from stable_speech import DACConfig, DACModel
+from parler_tts import DACConfig, DACModel
 from transformers.modeling_outputs import BaseModelOutput
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
@@ -73,7 +73,7 @@ from accelerate import Accelerator
 from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
 from accelerate.utils.memory import release_memory
-from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask
+from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSConfig, apply_delay_pattern_mask, build_delay_pattern_mask
 if is_wandb_available():
     from wandb import Audio
@@ -455,7 +455,7 @@ class DataTrainingArguments:
         }
     )
     wandb_project: str = field(
-        default="stable-speech",
+        default="parler-speech",
         metadata={"help": "The name of the wandb project."},
     )
     save_to_disk: str = field(
@@ -480,7 +480,7 @@ class DataTrainingArguments:
     )
 @dataclass
-class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
+class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
     dtype: Optional[str] = field(
         default="float32",
         metadata={
@@ -525,7 +525,7 @@ class DataCollatorEncodecWithPadding:
 @dataclass
-class DataCollatorStableSpeechWithPadding:
+class DataCollatorParlerTTSWithPadding:
     """
     Data collator that will dynamically pad the inputs received.
     Args:
@@ -716,14 +716,14 @@ def load_multiple_datasets(
     # TODO(YL): I forgot to create unique ids for MLS english.
     # To iterate faster, I bypass the original id check and do another one. - Done once because assuming it won't change next time
-    # if dataset_dict["name"] == "stable-speech/mls_eng_10k":
+    # if dataset_dict["name"] == "parler-tts/mls_eng_10k":
     #     def concat_ids(book_id, speaker_id, begin_time):
     #         return {"id": f"{book_id}_{speaker_id}_{str(begin_time).replace('.', '_')}"}
     #     dataset = dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
     #     metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
     #     metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
-    if dataset_dict["name"] != "stable-speech/mls_eng_10k":
+    if dataset_dict["name"] != "parler-tts/mls_eng_10k":
         if id_column_name is not None and id_column_name not in dataset.column_names:
             raise ValueError(
                 f"id_column_name={id_column_name} but has not been found in the dataset columns"
@@ -751,7 +751,7 @@ def load_multiple_datasets(
         dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
-        if id_column_name is not None and dataset_dict["name"] != "stable-speech/mls_eng_10k":
+        if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k":
            if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
                raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
@@ -785,7 +785,7 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, StableSpeechTrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
@@ -795,7 +795,7 @@ def main():
     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
-    send_example_telemetry("run_stable_speech", model_args, data_args)
+    send_example_telemetry("run_parler_tts", model_args, data_args)
     if training_args.dtype == "float16":
         mixed_precision = "fp16"
@@ -996,7 +996,7 @@ def main():
     # 3. Next, let's load the config.
     # TODO(YL): add the option to create the config from scratch
-    config = StableSpeechConfig.from_pretrained(
+    config = ParlerTTSConfig.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         token=data_args.token,
@@ -1011,7 +1011,7 @@ def main():
     })
     # create model + TODO(YL): not from_pretrained probably
-    model = StableSpeechForConditionalGeneration.from_pretrained(
+    model = ParlerTTSForConditionalGeneration.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         config=config,
@@ -1353,7 +1353,7 @@ def main():
     )
     # Instantiate custom data collator
-    data_collator = DataCollatorStableSpeechWithPadding(
+    data_collator = DataCollatorParlerTTSWithPadding(
         audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of,
         padding=padding, prompt_max_length=data_args.max_prompt_token_length, description_max_length=data_args.max_description_token_length, audio_max_length = audio_max_length
     )
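The `axis=1` concatenation and lambda filter in `load_multiple_datasets` above are the part most worth double-checking: column-wise concatenation silently relies on row order, so the id comparison is the only misalignment guard. A self-contained toy of the same check (toy data; `concatenate_datasets` and `filter` are the real `datasets` APIs):

```python
from datasets import Dataset, concatenate_datasets

# Toy stand-ins for the audio dataset and its metadata dataset.
dataset = Dataset.from_dict({"id": ["a", "b"], "audio": [1, 2]})
metadata = Dataset.from_dict({"metadata_id": ["a", "b"], "text": ["x", "y"]})

# axis=1 pastes columns side by side; rows must already be aligned.
dataset = concatenate_datasets([dataset, metadata], axis=1)

# Any row where the two id columns disagree means the concat misaligned.
mismatched = dataset.filter(lambda id1, id2: id1 != id2, input_columns=["id", "metadata_id"])
if len(mismatched) != 0:
    raise ValueError("Concatenate didn't work. Some ids don't correspond")
```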
...
@@ -41,7 +41,7 @@ with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
     long_description = f.read()
 # read version
-with open(os.path.join(here, "stable_speech", "__init__.py"), encoding="utf-8") as f:
+with open(os.path.join(here, "parler_tts", "__init__.py"), encoding="utf-8") as f:
     for line in f:
         if line.startswith("__version__"):
             version = line.split("=")[1].strip().strip('"')
@@ -50,7 +50,7 @@ with open(os.path.join(here, "stable_speech", "__init__.py"), encoding="utf-8")
     raise RuntimeError("Unable to find version string.")
 setuptools.setup(
-    name="stable_speech",
+    name="parler_tts",
     version=version,
     description="Toolkit for reproducing Stability AI's text-to-speech model.",
     long_description=long_description,
...