Commit 91542bfa authored by Yoach Lacombe

remove stable speech mentions

parent 315d3fca
-# Stable Speech
+# Parler-TTS
 Work in-progress reproduction of the text-to-speech (TTS) model from the paper [Natural language guidance of high-fidelity text-to-speech with synthetic annotations](https://www.text-description-to-speech.com)
 by Dan Lyth and Simon King, from Stability AI and Edinburgh University respectively.
...
-from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
+from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
 from transformers import T5Config, EncodecConfig
 from transformers import AutoConfig
@@ -12,7 +12,7 @@ encodec = AutoConfig.from_pretrained(encodec_version)
 encodec_vocab_size = encodec.codebook_size
-decoder_config = StableSpeechDecoderConfig(
+decoder_config = ParlerTTSDecoderConfig(
     vocab_size=encodec_vocab_size+1,
     max_position_embeddings=2048,
     num_hidden_layers=4,
@@ -34,13 +34,13 @@ decoder_config = StableSpeechDecoderConfig(
-decoder = StableSpeechForCausalLM(decoder_config)
+decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
-model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
+model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
...
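Note: `from_sub_models_pretrained` assembles the text encoder, audio encoder, and freshly initialised decoder into one composite model. A minimal sketch of how the composed checkpoint might then be saved and reloaded, following the pattern in the configuration docstring further down (the output path is illustrative, not from the scripts):

```python
# Hypothetical continuation of the conversion script above: persist the
# composed model, then reload it through the top-level config class.
model.save_pretrained("parler_tts-model")

from parler_tts import ParlerTTSConfig, ParlerTTSForConditionalGeneration

config = ParlerTTSConfig.from_pretrained("parler_tts-model")
model = ParlerTTSForConditionalGeneration.from_pretrained("parler_tts-model", config=config)
```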
-from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
+from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
 from transformers import AutoConfig
 from transformers import AutoModel
 from transformers import AutoConfig, AutoModel
-from stable_speech import DACConfig, DACModel
+from parler_tts import DACConfig, DACModel
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
@@ -19,7 +19,7 @@ encodec = AutoConfig.from_pretrained(encodec_version)
 encodec_vocab_size = encodec.codebook_size
-decoder_config = StableSpeechDecoderConfig(
+decoder_config = ParlerTTSDecoderConfig(
     vocab_size=encodec_vocab_size+1,
     max_position_embeddings=2048,
     num_hidden_layers=4,
@@ -41,13 +41,13 @@ decoder_config = StableSpeechDecoderConfig(
-decoder = StableSpeechForCausalLM(decoder_config)
+decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
-model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
+model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
...
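Note: the `AutoConfig.register` / `AutoModel.register` calls above let the generic Auto classes resolve the custom DAC wrapper by its `"dac"` model type. A small self-contained sketch of what that enables (toy usage; assumes `parler_tts` is installed locally):

```python
from transformers import AutoConfig, AutoModel
from parler_tts import DACConfig, DACModel

# Register the custom codec wrapper under the "dac" model type once.
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)

# The Auto classes can now resolve "dac" by name.
config = AutoConfig.for_model("dac")   # -> a DACConfig with default values
model = AutoModel.from_config(config)  # -> a randomly initialised DACModel
assert isinstance(model, DACModel)
```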
-from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
+from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
 from transformers import T5Config, EncodecConfig
 from transformers import AutoConfig
 from transformers import AutoConfig, AutoModel
-from stable_speech import DACConfig, DACModel
+from parler_tts import DACConfig, DACModel
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
@@ -20,7 +20,7 @@ encodec = AutoConfig.from_pretrained(encodec_version)
 encodec_vocab_size = encodec.codebook_size
-decoder_config = StableSpeechDecoderConfig(
+decoder_config = ParlerTTSDecoderConfig(
     vocab_size=encodec_vocab_size+1,
     max_position_embeddings=3000, # 30 s = 2580
     num_hidden_layers=12,
@@ -40,13 +40,13 @@ decoder_config = StableSpeechDecoderConfig(
 )
-decoder = StableSpeechForCausalLM(decoder_config)
+decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
-model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
+model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
...
-from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
+from parler_tts import ParlerTTSConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
 from transformers import T5Config, EncodecConfig
 from transformers import AutoConfig
 from transformers import AutoConfig, AutoModel
-from stable_speech import DACConfig, DACModel
+from parler_tts import DACConfig, DACModel
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
@@ -20,7 +20,7 @@ encodec = AutoConfig.from_pretrained(encodec_version)
 encodec_vocab_size = encodec.codebook_size
-decoder_config = StableSpeechDecoderConfig(
+decoder_config = ParlerTTSDecoderConfig(
     vocab_size=encodec_vocab_size+1,
     max_position_embeddings=4096, # 30 s = 2580
     num_hidden_layers=8,
@@ -40,13 +40,13 @@ decoder_config = StableSpeechDecoderConfig(
 )
-decoder = StableSpeechForCausalLM(decoder_config)
+decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder_small/")
-model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
+model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder_small/",
...
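A note on the `# 30 s = 2580` comments in the two scripts above: 2580 tokens for 30 seconds implies a codec frame rate of 86 frames per second (86 × 30 = 2580), which matches the 44.1 kHz DAC codec registered in these scripts (hop size 512, 44100 / 512 ≈ 86.1 Hz). The `max_position_embeddings` values of 3000 and 4096 then simply leave headroom above that maximum sequence length. This is an inference from the numbers, not something stated in the scripts themselves.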
-from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
-from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
+from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
+from .modeling_parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
 from .dac_wrapper import DACConfig, DACModel
\ No newline at end of file
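After this commit, the package's public exports (as listed in the `__init__.py` above) are imported as follows; this simply restates the diff:

```python
from parler_tts import (
    ParlerTTSConfig,
    ParlerTTSDecoderConfig,
    ParlerTTSForCausalLM,
    ParlerTTSForConditionalGeneration,
    apply_delay_pattern_mask,
    build_delay_pattern_mask,
    DACConfig,
    DACModel,
)
```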
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Stable Speech model configuration"""
+""" Parler-TTS model configuration"""
 from transformers import AutoConfig, logging
 from transformers.configuration_utils import PretrainedConfig
@@ -21,17 +21,17 @@ from transformers.configuration_utils import PretrainedConfig
 logger = logging.get_logger(__name__)
 MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "facebook/stable_speech-small": "https://huggingface.co/facebook/stable_speech-small/resolve/main/config.json",
-    # See all StableSpeech models at https://huggingface.co/models?filter=stable_speech
+    "facebook/parler_tts-small": "https://huggingface.co/facebook/parler_tts-small/resolve/main/config.json",
+    # See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
 }
-class StableSpeechDecoderConfig(PretrainedConfig):
+class ParlerTTSDecoderConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of an [`StableSpeechDecoder`]. It is used to instantiate a
-    Stable Speech decoder according to the specified arguments, defining the model architecture. Instantiating a
-    configuration with the defaults will yield a similar configuration to that of the Stable Speech
-    [facebook/stable_speech-small](https://huggingface.co/facebook/stable_speech-small) architecture.
+    This is the configuration class to store the configuration of a [`ParlerTTSDecoder`]. It is used to instantiate a
+    Parler-TTS decoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the Parler-TTS
+    [facebook/parler_tts-small](https://huggingface.co/facebook/parler_tts-small) architecture.
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -39,8 +39,8 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 2049):
-            Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
-            represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
+            Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
+            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
             Dimensionality of the layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 24):
@@ -76,7 +76,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
             Whether input and output word embeddings should be tied.
     """
-    model_type = "stable_speech_decoder"
+    model_type = "parler_tts_decoder"
    keys_to_ignore_at_inference = ["past_key_values"]
     def __init__(
@@ -127,10 +127,10 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     )
-class StableSpeechConfig(PretrainedConfig):
+class ParlerTTSConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`StableSpeechModel`]. It is used to instantiate a
-    Stable Speech model according to the specified arguments, defining the text encoder, audio encoder and Stable Speech decoder
+    This is the configuration class to store the configuration of a [`ParlerTTSModel`]. It is used to instantiate a
+    Parler-TTS model according to the specified arguments, defining the text encoder, audio encoder and Parler-TTS decoder
     configs.
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
@@ -153,24 +153,24 @@ class StableSpeechConfig(PretrainedConfig):
     ```python
     >>> from transformers import (
-    ...     StableSpeechConfig,
-    ...     StableSpeechDecoderConfig,
+    ...     ParlerTTSConfig,
+    ...     ParlerTTSDecoderConfig,
     ...     T5Config,
     ...     EncodecConfig,
-    ...     StableSpeechForConditionalGeneration,
+    ...     ParlerTTSForConditionalGeneration,
     ... )
     >>> # Initializing text encoder, audio encoder, and decoder model configurations
     >>> text_encoder_config = T5Config()
     >>> audio_encoder_config = EncodecConfig()
-    >>> decoder_config = StableSpeechDecoderConfig()
-    >>> configuration = StableSpeechConfig.from_sub_models_config(
+    >>> decoder_config = ParlerTTSDecoderConfig()
+    >>> configuration = ParlerTTSConfig.from_sub_models_config(
     ...     text_encoder_config, audio_encoder_config, decoder_config
     ... )
-    >>> # Initializing a StableSpeechForConditionalGeneration (with random weights) from the facebook/stable_speech-small style configuration
-    >>> model = StableSpeechForConditionalGeneration(configuration)
+    >>> # Initializing a ParlerTTSForConditionalGeneration (with random weights) from the facebook/parler_tts-small style configuration
+    >>> model = ParlerTTSForConditionalGeneration(configuration)
     >>> # Accessing the model configuration
     >>> configuration = model.config
@@ -179,14 +179,14 @@ class StableSpeechConfig(PretrainedConfig):
     >>> config_decoder = model.config.decoder
     >>> # Saving the model, including its configuration
-    >>> model.save_pretrained("stable_speech-model")
+    >>> model.save_pretrained("parler_tts-model")
     >>> # loading model and config from pretrained folder
-    >>> stable_speech_config = StableSpeechConfig.from_pretrained("stable_speech-model")
-    >>> model = StableSpeechForConditionalGeneration.from_pretrained("stable_speech-model", config=stable_speech_config)
+    >>> parler_tts_config = ParlerTTSConfig.from_pretrained("parler_tts-model")
+    >>> model = ParlerTTSForConditionalGeneration.from_pretrained("parler_tts-model", config=parler_tts_config)
     ```"""
-    model_type = "stable_speech"
+    model_type = "parler_tts"
     is_composition = True
     def __init__(self, vocab_size=1024, **kwargs):
@@ -205,7 +205,7 @@ class StableSpeechConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
         self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
-        self.decoder = StableSpeechDecoderConfig(**decoder_config)
+        self.decoder = ParlerTTSDecoderConfig(**decoder_config)
         self.is_encoder_decoder = True
     @classmethod
@@ -213,15 +213,15 @@ class StableSpeechConfig(PretrainedConfig):
         cls,
         text_encoder_config: PretrainedConfig,
         audio_encoder_config: PretrainedConfig,
-        decoder_config: StableSpeechDecoderConfig,
+        decoder_config: ParlerTTSDecoderConfig,
         **kwargs,
     ):
         r"""
-        Instantiate a [`StableSpeechConfig`] (or a derived class) from text encoder, audio encoder and decoder
+        Instantiate a [`ParlerTTSConfig`] (or a derived class) from text encoder, audio encoder and decoder
         configurations.
         Returns:
-            [`StableSpeechConfig`]: An instance of a configuration object
+            [`ParlerTTSConfig`]: An instance of a configuration object
         """
         return cls(
...
This diff is collapsed.
@@ -4,7 +4,7 @@ import dac
 model_path = dac.utils.download(model_type="44khz")
 model = dac.DAC.load(model_path)
-from stable_speech import DACConfig, DACModel
+from parler_tts import DACConfig, DACModel
 hf_dac = DACModel(DACConfig())
 hf_dac.model.load_state_dict(model.state_dict())
...
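Once the original DAC weights are loaded into the wrapper as above, the natural next step (not shown in this hunk) would be to serialize it in the transformers format. A hedged sketch with an illustrative output path:

```python
# Hypothetical continuation: persist the wrapped codec so that it can be
# reloaded later through the registered config/model classes.
hf_dac.save_pretrained("./artefacts/dac_44khz")

from parler_tts import DACModel

reloaded = DACModel.from_pretrained("./artefacts/dac_44khz")
```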
@@ -63,7 +63,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version
 from transformers.integrations import is_wandb_available
 from transformers import AutoConfig, AutoModel
-from stable_speech import DACConfig, DACModel
+from parler_tts import DACConfig, DACModel
 from transformers.modeling_outputs import BaseModelOutput
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
@@ -73,7 +73,7 @@ from accelerate import Accelerator
 from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
 from accelerate.utils.memory import release_memory
-from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask
+from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSConfig, apply_delay_pattern_mask, build_delay_pattern_mask
 if is_wandb_available():
     from wandb import Audio
@@ -455,7 +455,7 @@ class DataTrainingArguments:
         }
     )
     wandb_project: str = field(
-        default="stable-speech",
+        default="parler-speech",
         metadata={"help": "The name of the wandb project."},
     )
     save_to_disk: str = field(
@@ -480,7 +480,7 @@
     )
 @dataclass
-class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
+class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
     dtype: Optional[str] = field(
         default="float32",
         metadata={
@@ -525,7 +525,7 @@ class DataCollatorEncodecWithPadding:
 @dataclass
-class DataCollatorStableSpeechWithPadding:
+class DataCollatorParlerTTSWithPadding:
     """
     Data collator that will dynamically pad the inputs received.
     Args:
@@ -716,14 +716,14 @@ def load_multiple_datasets(
     # TODO(YL): I forgot to create unique ids for MLS english.
     # To iterate faster, I bypass the original id check and do another one. - Done once because assuming it won't change next time
-    # if dataset_dict["name"] == "stable-speech/mls_eng_10k":
+    # if dataset_dict["name"] == "parler-tts/mls_eng_10k":
     #     def concat_ids(book_id, speaker_id, begin_time):
     #         return {"id": f"{book_id}_{speaker_id}_{str(begin_time).replace('.', '_')}"}
     #     dataset = dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
     #     metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
     #     metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
-    if dataset_dict["name"] != "stable-speech/mls_eng_10k":
+    if dataset_dict["name"] != "parler-tts/mls_eng_10k":
         if id_column_name is not None and id_column_name not in dataset.column_names:
             raise ValueError(
                 f"id_column_name={id_column_name} but has not been found in the dataset columns"
@@ -751,7 +751,7 @@ def load_multiple_datasets(
     dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
-    if id_column_name is not None and dataset_dict["name"] != "stable-speech/mls_eng_10k":
+    if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k":
         if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
             raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
@@ -785,7 +785,7 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, StableSpeechTrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
@@ -795,7 +795,7 @@ def main():
     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
-    send_example_telemetry("run_stable_speech", model_args, data_args)
+    send_example_telemetry("run_parler_tts", model_args, data_args)
     if training_args.dtype == "float16":
         mixed_precision = "fp16"
@@ -996,7 +996,7 @@ def main():
     # 3. Next, let's load the config.
     # TODO(YL): add the option to create the config from scratch
-    config = StableSpeechConfig.from_pretrained(
+    config = ParlerTTSConfig.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         token=data_args.token,
@@ -1011,7 +1011,7 @@ def main():
     })
     # create model + TODO(YL): not from_pretrained probably
-    model = StableSpeechForConditionalGeneration.from_pretrained(
+    model = ParlerTTSForConditionalGeneration.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         config=config,
@@ -1353,7 +1353,7 @@ def main():
     )
     # Instantiate custom data collator
-    data_collator = DataCollatorStableSpeechWithPadding(
+    data_collator = DataCollatorParlerTTSWithPadding(
         audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of,
         padding=padding, prompt_max_length=data_args.max_prompt_token_length, description_max_length=data_args.max_description_token_length, audio_max_length = audio_max_length
     )
...
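The `axis=1` concatenation and id cross-check in `load_multiple_datasets` above can be illustrated standalone. A toy sketch with hypothetical column values (not the script's real data):

```python
from datasets import Dataset, concatenate_datasets

# Toy stand-ins for the audio split and its precomputed metadata split.
dataset = Dataset.from_dict({"id": ["a", "b"], "audio_len": [1.2, 3.4]})
metadata_dataset = Dataset.from_dict({"metadata_id": ["a", "b"], "text": ["hi", "yo"]})

# axis=1 pastes the two tables side by side; rows must already be aligned.
dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)

# Same sanity check as the training script: every id must equal its
# metadata_id, otherwise row alignment was broken before the concat.
if len(dataset.filter(lambda id1, id2: id1 != id2,
                      input_columns=["id", "metadata_id"])) != 0:
    raise ValueError("Concatenate didn't work. Some ids don't correspond.")
```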
@@ -41,7 +41,7 @@ with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
     long_description = f.read()
 # read version
-with open(os.path.join(here, "stable_speech", "__init__.py"), encoding="utf-8") as f:
+with open(os.path.join(here, "parler_tts", "__init__.py"), encoding="utf-8") as f:
     for line in f:
         if line.startswith("__version__"):
             version = line.split("=")[1].strip().strip('"')
@@ -50,7 +50,7 @@ with open(os.path.join(here, "stable_speech", "__init__.py"), encoding="utf-8")
     raise RuntimeError("Unable to find version string.")
 setuptools.setup(
-    name="stable_speech",
+    name="parler_tts",
     version=version,
     description="Toolkit for reproducing Stability AI's text-to-speech model.",
     long_description=long_description,
...