Unverified Commit d83ff5ee authored by Connor Henderson, committed by GitHub

Add FastSpeech2Conformer (#23439)

* start - docs, SpeechT5 copy and rename

* add relevant code from FastSpeech2 draft, have tests pass

* make it an actual conformer, demo ex.

* matching inference with original repo, includes debug code

* refactor nn.Sequentials, start more desc. var names

* more renaming

* more renaming

* vocoder scratchwork

* matching vocoder outputs

* hifigan vocoder conversion script

* convert model script, rename some config vars

* replace postnet with speecht5's implementation

* passing common tests, file cleanup

* expand testing, add output hidden states and attention

* tokenizer + passing tokenizer tests

* variety of updates and tests

* g2p_en pckg setup

* import structure edits

* docstrings and cleanup

* repo consistency

* deps

* small cleanup

* forward signature param order

* address comments except for masks and labels

* address comments on attention_mask and labels

* address second round of comments

* remove old unneeded line

* address comments part 1

* address comments pt 2

* rename auto mapping

* fixes for failing tests

* address comments part 3 (bart-like, train loss)

* make style

* pass config where possible

* add forward method + tests to WithHifiGan model

* make style

* address arg passing and generate_speech comments

* address Arthur comments

* address Arthur comments pt2

* lint changes

* Sanchit comment

* add g2p-en to doctest deps

* move up self.encoder

* onnx compatible tensor method

* fix is symbolic

* fix paper url

* move models to espnet org

* make style

* make fix-copies

* update docstring

* Arthur comments

* update docstring w/ new updates

* add model architecture images

* header size

* md wording update

* make style
parent 6eba901d
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert FastSpeech2Conformer checkpoint."""
import argparse
import json
import re
from pathlib import Path
from tempfile import TemporaryDirectory

import torch
import yaml

from transformers import (
    FastSpeech2ConformerConfig,
    FastSpeech2ConformerModel,
    FastSpeech2ConformerTokenizer,
    logging,
)


logging.set_verbosity_info()
logger = logging.get_logger("transformers.models.FastSpeech2Conformer")

CONFIG_MAPPING = {
    "adim": "hidden_size",
    "aheads": "num_attention_heads",
    "conformer_dec_kernel_size": "decoder_kernel_size",
    "conformer_enc_kernel_size": "encoder_kernel_size",
    "decoder_normalize_before": "decoder_normalize_before",
    "dlayers": "decoder_layers",
    "dunits": "decoder_linear_units",
    "duration_predictor_chans": "duration_predictor_channels",
    "duration_predictor_kernel_size": "duration_predictor_kernel_size",
    "duration_predictor_layers": "duration_predictor_layers",
    "elayers": "encoder_layers",
    "encoder_normalize_before": "encoder_normalize_before",
    "energy_embed_dropout": "energy_embed_dropout",
    "energy_embed_kernel_size": "energy_embed_kernel_size",
    "energy_predictor_chans": "energy_predictor_channels",
    "energy_predictor_dropout": "energy_predictor_dropout",
    "energy_predictor_kernel_size": "energy_predictor_kernel_size",
    "energy_predictor_layers": "energy_predictor_layers",
    "eunits": "encoder_linear_units",
    "pitch_embed_dropout": "pitch_embed_dropout",
    "pitch_embed_kernel_size": "pitch_embed_kernel_size",
    "pitch_predictor_chans": "pitch_predictor_channels",
    "pitch_predictor_dropout": "pitch_predictor_dropout",
    "pitch_predictor_kernel_size": "pitch_predictor_kernel_size",
    "pitch_predictor_layers": "pitch_predictor_layers",
    "positionwise_conv_kernel_size": "positionwise_conv_kernel_size",
    "postnet_chans": "speech_decoder_postnet_units",
    "postnet_filts": "speech_decoder_postnet_kernel",
    "postnet_layers": "speech_decoder_postnet_layers",
    "reduction_factor": "reduction_factor",
    "stop_gradient_from_energy_predictor": "stop_gradient_from_energy_predictor",
    "stop_gradient_from_pitch_predictor": "stop_gradient_from_pitch_predictor",
    "transformer_dec_attn_dropout_rate": "decoder_attention_dropout_rate",
    "transformer_dec_dropout_rate": "decoder_dropout_rate",
    "transformer_dec_positional_dropout_rate": "decoder_positional_dropout_rate",
    "transformer_enc_attn_dropout_rate": "encoder_attention_dropout_rate",
    "transformer_enc_dropout_rate": "encoder_dropout_rate",
    "transformer_enc_positional_dropout_rate": "encoder_positional_dropout_rate",
    "use_cnn_in_conformer": "use_cnn_in_conformer",
    "use_macaron_style_in_conformer": "use_macaron_style_in_conformer",
    "use_masking": "use_masking",
    "use_weighted_masking": "use_weighted_masking",
    "idim": "input_dim",
    "odim": "num_mel_bins",
    "spk_embed_dim": "speaker_embed_dim",
    "langs": "num_languages",
    "spks": "num_speakers",
}

def remap_model_yaml_config(yaml_config_path):
    with Path(yaml_config_path).open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
        args = argparse.Namespace(**args)

    remapped_config = {}

    model_params = args.tts_conf["text2mel_params"]
    # espnet_config_key -> hf_config_key, any keys not included are ignored
    for espnet_config_key, hf_config_key in CONFIG_MAPPING.items():
        if espnet_config_key in model_params:
            remapped_config[hf_config_key] = model_params[espnet_config_key]

    return remapped_config, args.g2p, args.token_list
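

# For illustration (hypothetical values, not from a real espnet config): a text2mel_params
# entry like
#     {"adim": 384, "aheads": 2, "elayers": 4}
# would be remapped via CONFIG_MAPPING above to
#     {"hidden_size": 384, "num_attention_heads": 2, "encoder_layers": 4}
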
def convert_espnet_state_dict_to_hf(state_dict):
    new_state_dict = {}
    for key in state_dict:
        if "tts.generator.text2mel." in key:
            new_key = key.replace("tts.generator.text2mel.", "")
            if "postnet" in key:
                new_key = new_key.replace("postnet.postnet", "speech_decoder_postnet.layers")
                new_key = new_key.replace(".0.weight", ".conv.weight")
                new_key = new_key.replace(".1.weight", ".batch_norm.weight")
                new_key = new_key.replace(".1.bias", ".batch_norm.bias")
                new_key = new_key.replace(".1.running_mean", ".batch_norm.running_mean")
                new_key = new_key.replace(".1.running_var", ".batch_norm.running_var")
                new_key = new_key.replace(".1.num_batches_tracked", ".batch_norm.num_batches_tracked")
            if "feat_out" in key:
                if "weight" in key:
                    new_key = "speech_decoder_postnet.feat_out.weight"
                if "bias" in key:
                    new_key = "speech_decoder_postnet.feat_out.bias"
            if "encoder.embed.0.weight" in key:
                new_key = new_key.replace("0.", "")
            if "w_1" in key:
                new_key = new_key.replace("w_1", "conv1")
            if "w_2" in key:
                new_key = new_key.replace("w_2", "conv2")
            if "predictor.conv" in key:
                new_key = new_key.replace(".conv", ".conv_layers")
                pattern = r"(\d)\.(\d)"
                replacement = (
                    r"\1.conv" if ("2.weight" not in new_key) and ("2.bias" not in new_key) else r"\1.layer_norm"
                )
                new_key = re.sub(pattern, replacement, new_key)
            if "pitch_embed" in key or "energy_embed" in key:
                new_key = new_key.replace("0", "conv")
            if "encoders" in key:
                new_key = new_key.replace("encoders", "conformer_layers")
                new_key = new_key.replace("norm_final", "final_layer_norm")
                new_key = new_key.replace("norm_mha", "self_attn_layer_norm")
                new_key = new_key.replace("norm_ff_macaron", "ff_macaron_layer_norm")
                new_key = new_key.replace("norm_ff", "ff_layer_norm")
                new_key = new_key.replace("norm_conv", "conv_layer_norm")
            if "lid_emb" in key:
                new_key = new_key.replace("lid_emb", "language_id_embedding")
            if "sid_emb" in key:
                new_key = new_key.replace("sid_emb", "speaker_id_embedding")

            new_state_dict[new_key] = state_dict[key]

    return new_state_dict
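

# For example, tracing the rules above: an espnet key such as
#     "tts.generator.text2mel.encoder.encoders.0.norm_mha.weight"
# has its prefix stripped, "encoders" renamed to "conformer_layers", and "norm_mha"
# renamed to "self_attn_layer_norm", giving the HF parameter name
#     "encoder.conformer_layers.0.self_attn_layer_norm.weight"
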
@torch.no_grad()
def convert_FastSpeech2ConformerModel_checkpoint(
    checkpoint_path,
    yaml_config_path,
    pytorch_dump_folder_path,
    repo_id=None,
):
    model_params, tokenizer_name, vocab = remap_model_yaml_config(yaml_config_path)
    config = FastSpeech2ConformerConfig(**model_params)

    # Prepare the model
    model = FastSpeech2ConformerModel(config)

    espnet_checkpoint = torch.load(checkpoint_path)
    hf_compatible_state_dict = convert_espnet_state_dict_to_hf(espnet_checkpoint)

    model.load_state_dict(hf_compatible_state_dict)
    model.save_pretrained(pytorch_dump_folder_path)

    # Prepare the tokenizer
    with TemporaryDirectory() as tempdir:
        vocab = {token: id for id, token in enumerate(vocab)}
        vocab_file = Path(tempdir) / "vocab.json"
        with open(vocab_file, "w") as f:
            json.dump(vocab, f)
        should_strip_spaces = "no_space" in tokenizer_name
        tokenizer = FastSpeech2ConformerTokenizer(str(vocab_file), should_strip_spaces=should_strip_spaces)

    tokenizer.save_pretrained(pytorch_dump_folder_path)

    if repo_id:
        print("Pushing to the hub...")
        model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
    parser.add_argument(
        "--yaml_config_path", required=True, default=None, type=str, help="Path to config.yaml of model to convert"
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
    )

    args = parser.parse_args()

    convert_FastSpeech2ConformerModel_checkpoint(
        args.checkpoint_path,
        args.yaml_config_path,
        args.pytorch_dump_folder_path,
        args.push_to_hub,
    )
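

# A minimal sketch of invoking the converter programmatically instead of via the CLI entry
# point above (all file paths are hypothetical placeholders):
#
#     convert_FastSpeech2ConformerModel_checkpoint(
#         checkpoint_path="./espnet_checkpoint.pth",
#         yaml_config_path="./config.yaml",
#         pytorch_dump_folder_path="./fastspeech2_conformer_hf",
#     )
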
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert FastSpeech2Conformer HiFi-GAN checkpoint."""
import argparse
from pathlib import Path

import torch
import yaml

from transformers import FastSpeech2ConformerHifiGan, FastSpeech2ConformerHifiGanConfig, logging


logging.set_verbosity_info()
logger = logging.get_logger("transformers.models.FastSpeech2Conformer")

def load_weights(checkpoint, hf_model, config):
    vocoder_key_prefix = "tts.generator.vocoder."
    checkpoint = {k.replace(vocoder_key_prefix, ""): v for k, v in checkpoint.items() if vocoder_key_prefix in k}

    hf_model.apply_weight_norm()

    hf_model.conv_pre.weight_g.data = checkpoint["input_conv.weight_g"]
    hf_model.conv_pre.weight_v.data = checkpoint["input_conv.weight_v"]
    hf_model.conv_pre.bias.data = checkpoint["input_conv.bias"]

    for i in range(len(config.upsample_rates)):
        hf_model.upsampler[i].weight_g.data = checkpoint[f"upsamples.{i}.1.weight_g"]
        hf_model.upsampler[i].weight_v.data = checkpoint[f"upsamples.{i}.1.weight_v"]
        hf_model.upsampler[i].bias.data = checkpoint[f"upsamples.{i}.1.bias"]

    for i in range(len(config.upsample_rates) * len(config.resblock_kernel_sizes)):
        for j in range(len(config.resblock_dilation_sizes)):
            hf_model.resblocks[i].convs1[j].weight_g.data = checkpoint[f"blocks.{i}.convs1.{j}.1.weight_g"]
            hf_model.resblocks[i].convs1[j].weight_v.data = checkpoint[f"blocks.{i}.convs1.{j}.1.weight_v"]
            hf_model.resblocks[i].convs1[j].bias.data = checkpoint[f"blocks.{i}.convs1.{j}.1.bias"]

            hf_model.resblocks[i].convs2[j].weight_g.data = checkpoint[f"blocks.{i}.convs2.{j}.1.weight_g"]
            hf_model.resblocks[i].convs2[j].weight_v.data = checkpoint[f"blocks.{i}.convs2.{j}.1.weight_v"]
            hf_model.resblocks[i].convs2[j].bias.data = checkpoint[f"blocks.{i}.convs2.{j}.1.bias"]

    hf_model.conv_post.weight_g.data = checkpoint["output_conv.1.weight_g"]
    hf_model.conv_post.weight_v.data = checkpoint["output_conv.1.weight_v"]
    hf_model.conv_post.bias.data = checkpoint["output_conv.1.bias"]

    hf_model.remove_weight_norm()

def remap_hifigan_yaml_config(yaml_config_path):
    with Path(yaml_config_path).open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
        args = argparse.Namespace(**args)

    vocoder_type = args.tts_conf["vocoder_type"]
    if vocoder_type != "hifigan_generator":
        raise TypeError(f"Vocoder config must be for `hifigan_generator`, but got {vocoder_type}")

    remapped_dict = {}
    vocoder_params = args.tts_conf["vocoder_params"]

    # espnet_config_key -> hf_config_key
    key_mappings = {
        "channels": "upsample_initial_channel",
        "in_channels": "model_in_dim",
        "resblock_dilations": "resblock_dilation_sizes",
        "resblock_kernel_sizes": "resblock_kernel_sizes",
        "upsample_kernel_sizes": "upsample_kernel_sizes",
        "upsample_scales": "upsample_rates",
    }
    for espnet_config_key, hf_config_key in key_mappings.items():
        remapped_dict[hf_config_key] = vocoder_params[espnet_config_key]
    remapped_dict["sampling_rate"] = args.tts_conf["sampling_rate"]
    remapped_dict["normalize_before"] = False
    remapped_dict["leaky_relu_slope"] = vocoder_params["nonlinear_activation_params"]["negative_slope"]

    return remapped_dict
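

# For illustration (hypothetical values): a yaml whose vocoder_params contain
#     {"channels": 512, "upsample_scales": [8, 8, 2, 2]}
# would be remapped via key_mappings above to
#     {"upsample_initial_channel": 512, "upsample_rates": [8, 8, 2, 2]}
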
@torch.no_grad()
def convert_hifigan_checkpoint(
    checkpoint_path,
    pytorch_dump_folder_path,
    yaml_config_path=None,
    repo_id=None,
):
    if yaml_config_path is not None:
        config_kwargs = remap_hifigan_yaml_config(yaml_config_path)
        config = FastSpeech2ConformerHifiGanConfig(**config_kwargs)
    else:
        config = FastSpeech2ConformerHifiGanConfig()

    model = FastSpeech2ConformerHifiGan(config)

    orig_checkpoint = torch.load(checkpoint_path)
    load_weights(orig_checkpoint, model, config)

    model.save_pretrained(pytorch_dump_folder_path)

    if repo_id:
        print("Pushing to the hub...")
        model.push_to_hub(repo_id)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
    parser.add_argument("--yaml_config_path", default=None, type=str, help="Path to config.yaml of model to convert")
    parser.add_argument(
        "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
    )

    args = parser.parse_args()

    convert_hifigan_checkpoint(
        args.checkpoint_path,
        args.pytorch_dump_folder_path,
        args.yaml_config_path,
        args.push_to_hub,
    )
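

# A minimal sketch of running the converted vocoder on a dummy spectrogram. The repo id is
# assumed (the commit message mentions moving the models to the espnet org) and is not
# verified here; the input shape follows the (batch, frames, model_in_dim) convention of the
# SpeechT5 HiFi-GAN this class is copied from:
#
#     import torch
#     from transformers import FastSpeech2ConformerHifiGan
#
#     vocoder = FastSpeech2ConformerHifiGan.from_pretrained("espnet/fastspeech2_conformer_hifigan")
#     spectrogram = torch.randn(1, 100, vocoder.config.model_in_dim)  # dummy mel input
#     waveform = vocoder(spectrogram)  # audio at vocoder.config.sampling_rate
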
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert FastSpeech2Conformer checkpoint."""
import argparse

import torch

from transformers import (
    FastSpeech2ConformerConfig,
    FastSpeech2ConformerHifiGan,
    FastSpeech2ConformerHifiGanConfig,
    FastSpeech2ConformerModel,
    FastSpeech2ConformerWithHifiGan,
    FastSpeech2ConformerWithHifiGanConfig,
    logging,
)

from .convert_fastspeech2_conformer_original_pytorch_checkpoint_to_pytorch import (
    convert_espnet_state_dict_to_hf,
    remap_model_yaml_config,
)
from .convert_hifigan import load_weights, remap_hifigan_yaml_config


logging.set_verbosity_info()
logger = logging.get_logger("transformers.models.FastSpeech2Conformer")

def convert_FastSpeech2ConformerWithHifiGan_checkpoint(
    checkpoint_path,
    yaml_config_path,
    pytorch_dump_folder_path,
    repo_id=None,
):
    # Prepare the model
    model_params, *_ = remap_model_yaml_config(yaml_config_path)
    model_config = FastSpeech2ConformerConfig(**model_params)

    model = FastSpeech2ConformerModel(model_config)

    espnet_checkpoint = torch.load(checkpoint_path)
    hf_compatible_state_dict = convert_espnet_state_dict_to_hf(espnet_checkpoint)
    model.load_state_dict(hf_compatible_state_dict)

    # Prepare the vocoder
    config_kwargs = remap_hifigan_yaml_config(yaml_config_path)
    vocoder_config = FastSpeech2ConformerHifiGanConfig(**config_kwargs)

    vocoder = FastSpeech2ConformerHifiGan(vocoder_config)
    load_weights(espnet_checkpoint, vocoder, vocoder_config)

    # Prepare the model + vocoder
    config = FastSpeech2ConformerWithHifiGanConfig.from_sub_model_configs(model_config, vocoder_config)
    with_hifigan_model = FastSpeech2ConformerWithHifiGan(config)
    with_hifigan_model.model = model
    with_hifigan_model.vocoder = vocoder

    with_hifigan_model.save_pretrained(pytorch_dump_folder_path)

    if repo_id:
        print("Pushing to the hub...")
        with_hifigan_model.push_to_hub(repo_id)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
    parser.add_argument(
        "--yaml_config_path", required=True, default=None, type=str, help="Path to config.yaml of model to convert"
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        required=True,
        default=None,
        type=str,
        help="Path to the output `FastSpeech2ConformerModel` PyTorch model.",
    )
    parser.add_argument(
        "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
    )

    args = parser.parse_args()

    convert_FastSpeech2ConformerWithHifiGan_checkpoint(
        args.checkpoint_path,
        args.yaml_config_path,
        args.pytorch_dump_folder_path,
        args.push_to_hub,
    )
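

# A minimal end-to-end sketch of using the converted model + vocoder for text-to-speech.
# The with-hifigan repo id is assumed (following the espnet-org naming used elsewhere in
# this commit) and is not verified here:
#
#     from transformers import FastSpeech2ConformerTokenizer, FastSpeech2ConformerWithHifiGan
#
#     tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
#     model = FastSpeech2ConformerWithHifiGan.from_pretrained("espnet/fastspeech2_conformer_with_hifigan")
#
#     inputs = tokenizer("Hello, my dog is cute.", return_tensors="pt")
#     output_dict = model(input_ids=inputs["input_ids"], return_dict=True)
#     waveform = output_dict["waveform"]  # audio tensor at the vocoder's sampling_rate
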
# coding=utf-8
# Copyright 2023 The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for FastSpeech2Conformer."""
import json
import os
from typing import Optional, Tuple

import regex

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging, requires_backends


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"espnet/fastspeech2_conformer": "https://huggingface.co/espnet/fastspeech2_conformer/raw/main/vocab.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
# Set to somewhat arbitrary large number as the model input
# isn't constrained by the relative positional encoding
"espnet/fastspeech2_conformer": 4096,
}
class FastSpeech2ConformerTokenizer(PreTrainedTokenizer):
    """
    Construct a FastSpeech2Conformer tokenizer.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        bos_token (`str`, *optional*, defaults to `"<sos/eos>"`):
            The begin of sequence token. Note that for FastSpeech2, it is the same as the `eos_token`.
        eos_token (`str`, *optional*, defaults to `"<sos/eos>"`):
            The end of sequence token. Note that for FastSpeech2, it is the same as the `bos_token`.
        pad_token (`str`, *optional*, defaults to `"<blank>"`):
            The token used for padding, for example when batching sequences of different lengths.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        should_strip_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to strip the spaces from the list of tokens.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    model_input_names = ["input_ids", "attention_mask"]
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        bos_token="<sos/eos>",
        eos_token="<sos/eos>",
        pad_token="<blank>",
        unk_token="<unk>",
        should_strip_spaces=False,
        **kwargs,
    ):
        requires_backends(self, "g2p_en")

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)

        import g2p_en

        self.g2p = g2p_en.G2p()

        self.decoder = {v: k for k, v in self.encoder.items()}

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            should_strip_spaces=should_strip_spaces,
            **kwargs,
        )

        self.should_strip_spaces = should_strip_spaces

    @property
    def vocab_size(self):
        return len(self.decoder)

    def get_vocab(self):
        "Returns vocab as a dict"
        return dict(self.encoder, **self.added_tokens_encoder)

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        # expand symbols
        text = regex.sub(";", ",", text)
        text = regex.sub(":", ",", text)
        text = regex.sub("-", " ", text)
        text = regex.sub("&", "and", text)

        # strip unnecessary symbols
        text = regex.sub(r"[\(\)\[\]\<\>\"]+", "", text)

        # strip whitespaces
        text = regex.sub(r"\s+", " ", text)

        text = text.upper()

        return text, kwargs
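
    # For example, tracing the substitutions above on the (hypothetical) input
    #     'self-driving & "smart" cars'
    # gives
    #     'SELF DRIVING AND SMART CARS'
    # ("-" expanded to a space, "&" to "and", quotes stripped, then uppercased).
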
    def _tokenize(self, text):
        """Returns a tokenized string."""
        # phonemize
        tokens = self.g2p(text)

        if self.should_strip_spaces:
            tokens = list(filter(lambda s: s != " ", tokens))

        tokens.append(self.eos_token)

        return tokens
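
    # For example, with the g2p_en backend, `_tokenize("Hello")` yields ARPAbet-style phonemes
    # plus the eos token, roughly ["HH", "AH0", "L", "OW1", "<sos/eos>"] (the exact output
    # depends on the installed g2p_en version).
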
    def _convert_token_to_id(self, token):
        """Converts a token (str) into an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocab."""
        return self.decoder.get(index, self.unk_token)

    # Override since phonemes cannot be converted back to strings
    def decode(self, token_ids, **kwargs):
        logger.warning(
            "Phonemes cannot be reliably converted to a string due to the one-many mapping, converting to tokens instead."
        )
        return self.convert_ids_to_tokens(token_ids)

    # Override since phonemes cannot be converted back to strings
    def convert_tokens_to_string(self, tokens, **kwargs):
        logger.warning(
            "Phonemes cannot be reliably converted to a string due to the one-many mapping, returning the tokens."
        )
        return tokens
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.get_vocab(), ensure_ascii=False))

        return (vocab_file,)
    def __getstate__(self):
        state = self.__dict__.copy()
        state["g2p"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d

        try:
            import g2p_en

            self.g2p = g2p_en.G2p()
        except ImportError:
            raise ImportError(
                "You need to install g2p-en to use FastSpeech2ConformerTokenizer. "
                "See https://pypi.org/project/g2p-en/ for installation."
            )
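

# A minimal usage sketch (requires the g2p-en package; the checkpoint name comes from
# PRETRAINED_VOCAB_FILES_MAP above):
#
#     tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
#     ids = tokenizer("Hello world")["input_ids"]
#     tokenizer.decode(ids)  # returns phoneme tokens such as ["HH", "AH0", ...], not a string
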
@@ -2655,6 +2655,7 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
        return_dict: Optional[bool] = None,
        speaker_embeddings: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        stop_labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Seq2SeqSpectrogramOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -2973,6 +2974,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
        return_dict: Optional[bool] = None,
        speaker_embeddings: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        stop_labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Seq2SeqSpectrogramOutput]:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -67,6 +67,7 @@ from .utils import (
    is_flax_available,
    is_fsdp_available,
    is_ftfy_available,
    is_g2p_en_available,
    is_ipex_available,
    is_jieba_available,
    is_jinja_available,
@@ -365,6 +366,13 @@ def require_fsdp(test_case, min_version: str = "1.12.0"):
    )


def require_g2p_en(test_case):
    """
    Decorator marking a test that requires g2p_en. These tests are skipped when g2p_en isn't installed.
    """
    return unittest.skipUnless(is_g2p_en_available(), "test requires g2p_en")(test_case)


def require_safetensors(test_case):
    """
    Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed.
@@ -123,6 +123,7 @@ from .import_utils import (
    is_flax_available,
    is_fsdp_available,
    is_ftfy_available,
    is_g2p_en_available,
    is_in_notebook,
    is_ipex_available,
    is_jieba_available,
@@ -3422,6 +3422,37 @@ class FalconPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])


FASTSPEECH2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None


class FastSpeech2ConformerHifiGan(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class FastSpeech2ConformerModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class FastSpeech2ConformerPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class FastSpeech2ConformerWithHifiGan(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -94,6 +94,7 @@ except importlib.metadata.PackageNotFoundError:
except importlib.metadata.PackageNotFoundError:
    _faiss_available = False
_ftfy_available = _is_package_available("ftfy")
_g2p_en_available = _is_package_available("g2p_en")
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
_jieba_available = _is_package_available("jieba")
_jinja_available = _is_package_available("jinja2")
@@ -444,6 +445,10 @@ def is_ftfy_available():
    return _ftfy_available


def is_g2p_en_available():
    return _g2p_en_available


@lru_cache()
def is_torch_tpu_available(check_device=True):
    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
@@ -1059,6 +1064,12 @@ LEVENSHTEIN_IMPORT_ERROR = """
install python-Levenshtein`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
G2P_EN_IMPORT_ERROR = """
{0} requires the g2p-en library but it was not found in your environment. You can install it with pip:
`pip install g2p-en`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
PYTORCH_QUANTIZATION_IMPORT_ERROR = """
{0} requires the pytorch-quantization library but it was not found in your environment. You can install it with pip:
@@ -1101,7 +1112,6 @@ SACREMOSES_IMPORT_ERROR = """
`pip install sacremoses`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip:
@@ -1225,6 +1235,7 @@ BACKENDS_MAPPING = OrderedDict(
        ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
        ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
        ("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)),
        ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
        ("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)),
        ("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)),
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the FastSpeech2Conformer tokenizer."""
import unittest

from transformers.models.fastspeech2_conformer import FastSpeech2ConformerTokenizer
from transformers.testing_utils import require_g2p_en, slow

from ...test_tokenization_common import TokenizerTesterMixin

@require_g2p_en
class FastSpeech2ConformerTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = FastSpeech2ConformerTokenizer
    test_rust_tokenizer = False

    def setUp(self):
        super().setUp()
        tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
        tokenizer.save_pretrained(self.tmpdirname)

    def get_input_output_texts(self, tokenizer):
        input_text = "this is a test"
        output_text = "this is a test"
        return input_text, output_text

    # Custom `get_clean_sequence` since FastSpeech2ConformerTokenizer can't decode id -> string
    def get_clean_sequence(self, tokenizer, with_prefix_space=False, **kwargs):  # max_length=20, min_length=5
        input_text, output_text = self.get_input_output_texts(tokenizer)
        ids = tokenizer.encode(output_text, add_special_tokens=False)
        return output_text, ids
    def test_convert_token_and_id(self):
        """Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
        token = "<unk>"
        token_id = 1

        self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
        self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)

    def test_get_vocab(self):
        vocab_keys = list(self.get_tokenizer().get_vocab().keys())

        self.assertEqual(vocab_keys[0], "<blank>")
        self.assertEqual(vocab_keys[1], "<unk>")
        self.assertEqual(vocab_keys[-4], "UH0")
        self.assertEqual(vocab_keys[-2], "..")
        self.assertEqual(vocab_keys[-1], "<sos/eos>")
        self.assertEqual(len(vocab_keys), 78)

    def test_vocab_size(self):
        self.assertEqual(self.get_tokenizer().vocab_size, 78)
    @unittest.skip(
        "FastSpeech2Conformer tokenizer does not support adding tokens as they can't be added to the g2p_en backend"
    )
    def test_added_token_are_matched_longest_first(self):
        pass

    @unittest.skip(
        "FastSpeech2Conformer tokenizer does not support adding tokens as they can't be added to the g2p_en backend"
    )
    def test_added_tokens_do_lower_case(self):
        pass

    @unittest.skip(
        "FastSpeech2Conformer tokenizer does not support adding tokens as they can't be added to the g2p_en backend"
    )
    def test_tokenize_special_tokens(self):
        pass
    def test_full_tokenizer(self):
        tokenizer = self.get_tokenizer()

        tokens = tokenizer.tokenize("This is a test")
        ids = [9, 12, 6, 12, 11, 2, 4, 15, 6, 4, 77]
        self.assertListEqual(tokens, ["DH", "IH1", "S", "IH1", "Z", "AH0", "T", "EH1", "S", "T", "<sos/eos>"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), ids)
        self.assertListEqual(tokenizer.convert_ids_to_tokens(ids), tokens)
    @slow
    def test_tokenizer_integration(self):
        # Custom test since:
        # 1) This tokenizer only decodes to tokens (phonemes cannot be converted back to text with complete accuracy)
        # 2) The sequence used contains no numbers, since espnet applies custom number conversion
        #    ("32" is phonemized as "thirty two" there), whereas this tokenizer phonemizes "32" as
        #    "thirty-two" because the custom number handling is not implemented.
        sequences = [
            "Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides "
            "general-purpose architectures (BERT, GPT, RoBERTa, XLM, DistilBert, XLNet...) for Natural "
            "Language Understanding (NLU) and Natural Language Generation (NLG) with over thirty-two pretrained "
            "models in one hundred plus languages and deep interoperability between Jax, PyTorch and TensorFlow.",
            "BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly "
            "conditioning on both left and right context in all layers.",
            "The quick brown fox jumps over the lazy dog.",
        ]
        tokenizer = FastSpeech2ConformerTokenizer.from_pretrained(
            "espnet/fastspeech2_conformer", revision="07f9c4a2d6bbc69b277d87d2202ad1e35b05e113"
        )
        actual_encoding = tokenizer(sequences)
        # fmt: off
        expected_encoding = {
            'input_ids': [
[4, 7, 60, 3, 6, 22, 30, 7, 14, 21, 11, 22, 30, 7, 14, 21, 8, 29, 3, 34, 3, 18, 11, 17, 12, 4, 21, 10, 4, 7, 60, 3, 6, 22, 30, 7, 14, 21, 11, 2, 3, 5, 17, 12, 4, 21, 10, 17, 7, 29, 4, 7, 31, 3, 5, 25, 38, 4, 17, 7, 2, 20, 32, 5, 11, 40, 15, 3, 21, 2, 8, 17, 38, 17, 2, 6, 24, 7, 10, 2, 4, 45, 10, 39, 21, 11, 25, 38, 4, 23, 37, 15, 4, 6, 23, 7, 2, 25, 38, 4, 2, 23, 11, 8, 15, 14, 11, 23, 5, 13, 6, 4, 12, 8, 4, 21, 25, 23, 11, 8, 15, 3, 39, 2, 8, 1, 22, 30, 7, 3, 18, 39, 21, 2, 8, 8, 18, 36, 37, 16, 2, 40, 62, 3, 5, 21, 6, 4, 18, 3, 5, 13, 36, 3, 8, 28, 2, 3, 5, 3, 18, 39, 21, 2, 8, 8, 18, 36, 37, 16, 2, 40, 40, 45, 3, 21, 31, 35, 2, 3, 15, 8, 36, 16, 12, 9, 34, 20, 21, 43, 38, 5, 29, 4, 28, 17, 7, 29, 4, 7, 31, 3, 5, 14, 24, 5, 2, 8, 11, 13, 3, 16, 19, 3, 26, 19, 3, 5, 7, 2, 5, 17, 8, 19, 6, 8, 18, 36, 37, 16, 2, 40, 2, 11, 2, 3, 5, 5, 27, 17, 49, 3, 4, 21, 2, 17, 21, 25, 12, 8, 2, 4, 29, 25, 13, 4, 16, 27, 3, 40, 18, 10, 6, 23, 17, 12, 4, 21, 10, 2, 3, 5, 4, 15, 3, 6, 21, 8, 46, 22, 33, 77],
[25, 38, 4, 12, 11, 5, 13, 11, 32, 3, 5, 4, 28, 17, 7, 27, 4, 7, 31, 3, 5, 27, 17, 25, 51, 5, 13, 7, 15, 10, 35, 2, 3, 2, 8, 7, 45, 17, 7, 2, 11, 2, 3, 4, 31, 35, 2, 3, 11, 22, 7, 19, 14, 2, 3, 8, 31, 25, 2, 8, 5, 4, 15, 10, 6, 4, 25, 32, 40, 55, 3, 4, 8, 29, 10, 2, 3, 5, 12, 35, 2, 3, 13, 36, 24, 3, 25, 34, 43, 8, 15, 22, 4, 2, 3, 5, 7, 32, 4, 10, 24, 3, 4, 54, 10, 6, 4, 13, 3, 30, 8, 8, 31, 21, 11, 33, 77],
[9, 2, 10, 16, 12, 10, 25, 7, 42, 3, 22, 24, 10, 6, 40, 19, 14, 17, 6, 34, 20, 21, 9, 2, 8, 31, 11, 29, 5, 30, 37, 33, 77]
],
'attention_mask': [
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
]
        }
        # fmt: on

        actual_tokens = [tokenizer.decode(input_ids) for input_ids in expected_encoding["input_ids"]]
        expected_tokens = [
            [tokenizer.convert_ids_to_tokens(id) for id in sequence] for sequence in expected_encoding["input_ids"]
        ]

        self.assertListEqual(actual_encoding["input_ids"], expected_encoding["input_ids"])
        self.assertListEqual(actual_encoding["attention_mask"], expected_encoding["attention_mask"])
        self.assertTrue(actual_tokens == expected_tokens)
    @unittest.skip(
        reason="FastSpeech2Conformer tokenizer does not support adding tokens as they can't be added to the g2p_en backend"
    )
    def test_add_tokens_tokenizer(self):
        pass

    @unittest.skip(
        reason="FastSpeech2Conformer tokenizer does not support adding tokens as they can't be added to the g2p_en backend"
    )
    def test_add_special_tokens(self):
        pass

    @unittest.skip(
        reason="FastSpeech2Conformer tokenizer does not support adding tokens as they can't be added to the g2p_en backend"
    )
    def test_added_token_serializable(self):
        pass

    @unittest.skip(
        reason="FastSpeech2Conformer tokenizer does not support adding tokens as they can't be added to the g2p_en backend"
    )
    def test_save_and_load_tokenizer(self):
        pass

    @unittest.skip(reason="Phonemes cannot be reliably converted to string due to one-many mapping")
    def test_internal_consistency(self):
        pass

    @unittest.skip(reason="Phonemes cannot be reliably converted to string due to one-many mapping")
    def test_encode_decode_with_spaces(self):
        pass

    @unittest.skip(reason="Phonemes cannot be reliably converted to string due to one-many mapping")
    def test_convert_tokens_to_string_format(self):
        pass

    @unittest.skip("FastSpeech2Conformer tokenizer does not support pairs.")
    def test_maximum_encoding_length_pair_input(self):
        pass

    @unittest.skip(
        "FastSpeech2Conformer tokenizer appends eos_token to each string it's passed, including `is_split_into_words=True`."
    )
    def test_pretokenized_inputs(self):
        pass

    @unittest.skip(
        reason="g2p_en is slow with large inputs and max encoding length is not a concern for FastSpeech2Conformer"
    )
    def test_maximum_encoding_length_single_input(self):
        pass
@@ -123,6 +123,7 @@ SPECIAL_CASES_TO_ALLOW.update(
        "DinatConfig": True,
        "DonutSwinConfig": True,
        "EfficientFormerConfig": True,
        "FastSpeech2ConformerConfig": True,
        "FSMTConfig": True,
        "JukeboxConfig": True,
        "LayoutLMv2Config": True,
@@ -90,6 +90,8 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
    "UMT5EncoderModel",  # Building part of bigger (tested) model.
    "Blip2QFormerModel",  # Building part of bigger (tested) model.
    "ErnieMForInformationExtraction",
    "FastSpeech2ConformerHifiGan",  # Already tested by SpeechT5HifiGan (# Copied from)
    "FastSpeech2ConformerWithHifiGan",  # Built with two smaller (tested) models.
    "GraphormerDecoderHead",  # Building part of bigger (tested) model.
    "JukeboxVQVAE",  # Building part of bigger (tested) model.
    "JukeboxPrior",  # Building part of bigger (tested) model.
@@ -159,6 +161,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    "Blip2QFormerModel",
    "Blip2VisionModel",
    "ErnieMForInformationExtraction",
    "FastSpeech2ConformerHifiGan",
    "FastSpeech2ConformerWithHifiGan",
    "GitVisionModel",
    "GraphormerModel",
    "GraphormerForGraphClassification",