Commit 0112b0f0 authored by chenzk

v1.0
import logging
from lightning.pytorch.utilities import rank_zero_only
def get_pylogger(name: str = __name__) -> logging.Logger:
"""Initializes a multi-GPU-friendly python command line logger.
:param name: The name of the logger, defaults to ``__name__``.
:return: A logger object.
"""
logger = logging.getLogger(name)
# this ensures all logging levels get marked with the rank zero decorator
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
for level in logging_levels:
setattr(logger, level, rank_zero_only(getattr(logger, level)))
return logger
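# A minimal usage sketch (assuming a Lightning multi-GPU run): every process
# can create the logger, but the rank_zero_only wrapping above means only
# global rank 0 actually emits records, so messages are not duplicated per GPU:
#
#     log = get_pylogger(__name__)
#     log.info("printed once, regardless of world size")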
from pathlib import Path
from typing import Sequence
import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt
from matcha.utils import pylogger
log = pylogger.get_pylogger(__name__)
@rank_zero_only
def print_config_tree(
cfg: DictConfig,
print_order: Sequence[str] = (
"data",
"model",
"callbacks",
"logger",
"trainer",
"paths",
"extras",
),
resolve: bool = False,
save_to_file: bool = False,
) -> None:
"""Prints the contents of a DictConfig as a tree structure using the Rich library.
:param cfg: A DictConfig composed by Hydra.
:param print_order: Determines in what order config components are printed. Default is ``("data", "model",
"callbacks", "logger", "trainer", "paths", "extras")``.
:param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
:param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
"""
style = "dim"
tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
queue = []
# add fields from `print_order` to queue
for field in print_order:
_ = (
queue.append(field)
if field in cfg
else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
)
# add all the other fields to queue (not specified in `print_order`)
for field in cfg:
if field not in queue:
queue.append(field)
# generate config tree from queue
for field in queue:
branch = tree.add(field, style=style, guide_style=style)
config_group = cfg[field]
if isinstance(config_group, DictConfig):
branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
else:
branch_content = str(config_group)
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
# print config tree
rich.print(tree)
# save config tree to file
if save_to_file:
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file:
rich.print(tree, file=file)
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
"""Prompts user to input tags from command line if no tags are provided in config.
:param cfg: A DictConfig composed by Hydra.
:param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
"""
if not cfg.get("tags"):
if "id" in HydraConfig().cfg.hydra.job:
raise ValueError("Specify tags before launching a multirun!")
log.warning("No tags provided in config. Prompting user to input tags...")
tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
tags = [t.strip() for t in tags.split(",") if t != ""]
with open_dict(cfg):
cfg.tags = tags
log.info(f"Tags: {cfg.tags}")
if save_to_file:
with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
rich.print(cfg.tags, file=file)
import os
import sys
import warnings
from importlib.util import find_spec
from math import ceil
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
import gdown
import matplotlib.pyplot as plt
import numpy as np
import torch
import wget
from omegaconf import DictConfig
from matcha.utils import pylogger, rich_utils
log = pylogger.get_pylogger(__name__)
def extras(cfg: DictConfig) -> None:
"""Applies optional utilities before the task is started.
Utilities:
- Ignoring python warnings
- Setting tags from command line
- Rich config printing
:param cfg: A DictConfig object containing the config tree.
"""
# return if no `extras` config
if not cfg.get("extras"):
log.warning("Extras config not found! <cfg.extras=null>")
return
# disable python warnings
if cfg.extras.get("ignore_warnings"):
log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
warnings.filterwarnings("ignore")
# prompt user to input tags from command line if none are provided in the config
if cfg.extras.get("enforce_tags"):
log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
rich_utils.enforce_tags(cfg, save_to_file=True)
# pretty print config tree using Rich library
if cfg.extras.get("print_config"):
log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
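# A minimal sketch of driving `extras` directly with a hand-built config tree
# (hypothetical values; in this project the tree is composed by Hydra):
#
#     from omegaconf import OmegaConf
#     cfg = OmegaConf.create(
#         {"extras": {"ignore_warnings": True, "enforce_tags": False, "print_config": False}}
#     )
#     extras(cfg)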
def task_wrapper(task_func: Callable) -> Callable:
"""Optional decorator that controls the failure behavior when executing the task function.
This wrapper can be used to:
- make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
- save the exception to a `.log` file
- mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
- etc. (adjust depending on your needs)
Example:
```
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
...
return metric_dict, object_dict
```
:param task_func: The task function to be wrapped.
:return: The wrapped task function.
"""
def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# execute the task
try:
metric_dict, object_dict = task_func(cfg=cfg)
# things to do if exception occurs
except Exception as ex:
# save exception to `.log` file
log.exception("")
# some hyperparameter combinations might be invalid or cause out-of-memory errors
# so when using hparam search plugins like Optuna, you might want to disable
# raising the below exception to avoid multirun failure
raise ex
# things to always do after either success or exception
finally:
# display output dir path in terminal
log.info(f"Output dir: {cfg.paths.output_dir}")
# always close wandb run (even if exception occurs so multirun won't fail)
if find_spec("wandb"): # check if wandb is installed
import wandb
if wandb.run:
log.info("Closing wandb!")
wandb.finish()
return metric_dict, object_dict
return wrap
def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> Optional[float]:
"""Safely retrieves value of the metric logged in LightningModule.
:param metric_dict: A dict containing metric values.
:param metric_name: The name of the metric to retrieve.
:return: The value of the metric, or ``None`` if no metric name was given.
"""
if not metric_name:
log.info("Metric name is None! Skipping metric value retrieval...")
return None
if metric_name not in metric_dict:
raise ValueError(
f"Metric value not found! <metric_name={metric_name}>\n"
"Make sure metric name logged in LightningModule is correct!\n"
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
)
metric_value = metric_dict[metric_name].item()
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
return metric_value
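# Example with a hypothetical metric dict, as a LightningModule would log it:
#
#     metric_dict = {"val/loss": torch.tensor(0.25)}
#     get_metric_value(metric_dict, "val/loss")  # -> 0.25
#     get_metric_value(metric_dict, None)        # -> None (retrieval skipped)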
def intersperse(lst, item):
# Adds blank symbol
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
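# Example: interspersing a blank symbol id (here 0, hypothetical) between and
# around token ids, as is common for TTS inputs with monotonic alignment:
#
#     intersperse([5, 9, 7], 0)  # -> [0, 5, 0, 9, 0, 7, 0]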
def save_figure_to_numpy(fig):
# np.fromstring is deprecated for binary data; frombuffer (plus a copy, to
# keep the array writable) is the supported replacement
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).copy()
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
return data
def plot_tensor(tensor):
plt.style.use("default")
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.tight_layout()
fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close()
return data
def save_plot(tensor, savepath):
plt.style.use("default")
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.tight_layout()
fig.canvas.draw()
plt.savefig(savepath)
plt.close()
def to_numpy(tensor):
if isinstance(tensor, np.ndarray):
return tensor
elif isinstance(tensor, torch.Tensor):
return tensor.detach().cpu().numpy()
elif isinstance(tensor, list):
return np.array(tensor)
else:
raise TypeError("Unsupported type for conversion to numpy array")
def get_user_data_dir(appname="matcha_tts"):
"""
Args:
appname (str): Name of application
Returns:
Path: path to user data directory
"""
MATCHA_HOME = os.environ.get("MATCHA_HOME")
if MATCHA_HOME is not None:
ans = Path(MATCHA_HOME).expanduser().resolve(strict=False)
elif sys.platform == "win32":
import winreg # pylint: disable=import-outside-toplevel
key = winreg.OpenKey(
winreg.HKEY_CURRENT_USER,
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
)
dir_, _ = winreg.QueryValueEx(key, "Local AppData")
ans = Path(dir_).resolve(strict=False)
elif sys.platform == "darwin":
ans = Path("~/Library/Application Support/").expanduser()
else:
ans = Path.home().joinpath(".local/share")
final_path = ans.joinpath(appname)
final_path.mkdir(parents=True, exist_ok=True)
return final_path
def assert_model_downloaded(checkpoint_path, url, use_wget=True):
if Path(checkpoint_path).exists():
log.debug(f"[+] Model already present at {checkpoint_path}!")
print(f"[+] Model already present at {checkpoint_path}!")
return
log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
print(f"[-] Model not found at {checkpoint_path}! Will download it")
checkpoint_path = str(checkpoint_path)
if not use_wget:
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
else:
wget.download(url=url, out=checkpoint_path)
def get_phoneme_durations(durations, phones):
prev = durations[0]
merged_durations = []
# Convolve with stride 2
for i in range(1, len(durations), 2):
if i == len(durations) - 2:
# if it is last take full value
next_half = durations[i + 1]
else:
next_half = ceil(durations[i + 1] / 2)
curr = prev + durations[i] + next_half
prev = durations[i + 1] - next_half
merged_durations.append(curr)
assert len(phones) == len(merged_durations)
assert len(merged_durations) == (len(durations) - 1) // 2
merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
start = torch.tensor(0)
duration_json = []
for i, duration in enumerate(merged_durations):
duration_json.append(
{
phones[i]: {
"starttime": start.item(),
"endtime": duration.item(),
"duration": duration.item() - start.item(),
}
}
)
start = duration
assert list(duration_json[-1].values())[0]["endtime"] == sum(
durations
), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
return duration_json
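# Worked example with hypothetical frame counts: durations alternate
# [blank, phone, blank, phone, blank], and each blank is split between its
# neighbouring phones by the stride-2 merge above:
#
#     get_phoneme_durations([2, 3, 4, 5, 6], ["a", "b"])
#     # -> [{'a': {'starttime': 0, 'endtime': 7, 'duration': 7}},
#     #     {'b': {'starttime': 7, 'endtime': 20, 'duration': 13}}]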
[build-system]
requires = ["setuptools", "wheel", "cython==0.29.35", "numpy==1.24.3", "packaging"]
[tool.black]
line-length = 120
target-version = ['py310']
exclude = '''
(
/(
\.eggs # exclude a few common directories in the
| \.git # root of the project
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
| foo.py # also separately exclude a file named foo.py in
# the root of the project
)
'''
[tool.pytest.ini_options]
addopts = [
"--color=yes",
"--durations=0",
"--strict-markers",
"--doctest-modules",
]
filterwarnings = [
"ignore::DeprecationWarning",
"ignore::UserWarning",
]
log_cli = "True"
markers = [
"slow: slow tests",
]
minversion = "6.0"
testpaths = "tests/"
[tool.coverage.report]
exclude_lines = [
"pragma: nocover",
"raise NotImplementedError",
"raise NotImplementedError()",
"if __name__ == .__main__.:",
]
# --------- pytorch --------- #
torch>=2.0.0
torchvision>=0.15.0
lightning>=2.0.0
torchmetrics>=0.11.4
# --------- hydra --------- #
hydra-core==1.3.2
hydra-colorlog==1.2.0
hydra-optuna-sweeper==1.2.0
# --------- loggers --------- #
# wandb
# neptune-client
# mlflow
# comet-ml
# aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550
# --------- others --------- #
rootutils # standardizing the project root setup
pre-commit # hooks for applying linters on commit
rich # beautiful text formatting in terminal
pytest # tests
# sh # for running bash commands in some tests (linux/macos only)
phonemizer # phonemization of text
tensorboard
librosa
Cython
numpy
einops
inflect
Unidecode
scipy
torchaudio
matplotlib
pandas
conformer==0.3.2
diffusers # developed using version ==0.25.0
notebook
ipywidgets
gradio==3.43.2
gdown
wget
seaborn
#!/bin/bash
# Schedule execution of many runs
# Run from root folder with: bash scripts/schedule.sh
python src/train.py trainer.max_epochs=5 logger=csv
python src/train.py trainer.max_epochs=10 logger=csv
#!/usr/bin/env python
import os
import numpy
from Cython.Build import cythonize
from setuptools import Extension, find_packages, setup
exts = [
Extension(
name="matcha.utils.monotonic_align.core",
sources=["matcha/utils/monotonic_align/core.pyx"],
)
]
with open("README.md", encoding="utf-8") as readme_file:
README = readme_file.read()
cwd = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(cwd, "matcha", "VERSION"), encoding="utf-8") as fin:
version = fin.read().strip()
def get_requires():
requirements = os.path.join(os.path.dirname(__file__), "requirements.txt")
with open(requirements, encoding="utf-8") as reqfile:
# skip blank lines and comments so install_requires only receives real specifiers
return [r.strip() for r in reqfile if r.strip() and not r.strip().startswith("#")]
setup(
name="matcha-tts",
version=version,
description="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching",
long_description=README,
long_description_content_type="text/markdown",
author="Shivam Mehta",
author_email="shivam.mehta25@gmail.com",
url="https://shivammehta25.github.io/Matcha-TTS",
install_requires=get_requires(),
include_dirs=[numpy.get_include()],
include_package_data=True,
packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),
# use this to customize global commands available in the terminal after installing the package
entry_points={
"console_scripts": [
"matcha-data-stats=matcha.utils.generate_data_statistics:main",
"matcha-tts=matcha.cli:cli",
"matcha-tts-app=matcha.app:main",
"matcha-tts-get-durations=matcha.utils.get_durations_from_trained_model:main",
]
},
ext_modules=cythonize(exts, language_level=3),
python_requires=">=3.9.0",
)
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import torch
from tqdm import tqdm
import numpy as np
import torchaudio
from inspiremusic.utils.audio_utils import normalize, split_wav_into_chunks
from inspiremusic.music_tokenizer.vqvae import VQVAE
import time
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
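# `wav.scp` is read in the usual Kaldi convention, one "<utt_id> <wav_path>"
# pair per line, e.g. (hypothetical entry):
#
#     utt_0001 /data/wavs/utt_0001.wav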
def main(args):
audio_min_length = 1.0
audio_max_length = 30.0
max_chunk_size = int(args.sample_rate * audio_max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
utt2wav = {}
with open('{}/wav.scp'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2wav[l[0]] = l[1]
model = VQVAE(args.config_path, args.ckpt_path, with_encoder=True)
model.to(device)  # honor the CPU fallback selected above instead of hard-coding CUDA
model.eval()
utt2acoustic_token = {}
start_time = time.time()
for utt in tqdm(utt2wav.keys()):
audio, sample_rate = torchaudio.load(utt2wav[utt])
if sample_rate != args.sample_rate:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=args.sample_rate)(audio)
audio_length = audio.shape[1]
if audio_length > args.sample_rate * audio_min_length:
if audio_length > max_chunk_size:
wav_chunks = split_wav_into_chunks(audio_length, audio, max_chunk_size)
for chunk in wav_chunks:
chunk = torch.as_tensor(chunk, dtype=torch.float32).to(device)
acoustic_token = model.encode(chunk)
if acoustic_token.is_cuda:
acoustic_token = acoustic_token.cpu()
acoustic_token = acoustic_token.numpy().astype(np.int16)
if utt not in utt2acoustic_token.keys():
utt2acoustic_token[utt] = acoustic_token
else:
utt2acoustic_token[utt] = np.concatenate((utt2acoustic_token[utt], acoustic_token), axis=1)
else:
audio = audio.to(device=device, dtype=torch.float32)
acoustic_token = model.encode(audio)
if acoustic_token.is_cuda:
acoustic_token = acoustic_token.cpu()
acoustic_token = acoustic_token.numpy().astype(np.int16)
utt2acoustic_token[utt] = acoustic_token
else:
logging.warning(f'{utt} is shorter than {audio_min_length}s, skipping.')
torch.save(utt2acoustic_token, '{}/utt2acoustic_token.pt'.format(args.dir))
logging.info('elapsed time {:.2f}s'.format(time.time() - start_time))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dir',
type=str)
parser.add_argument('--config_path',
type=str, default="pretrained_models/InspireMusic-Base/music_tokenizer/config.json")
parser.add_argument('--ckpt_path',
type=str, default="pretrained_models/InspireMusic-Base/music_tokenizer/model.pt")
parser.add_argument('--sample_rate',
default=24000,
type=int)
args = parser.parse_args()
main(args)
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
import torchaudio
from tqdm import tqdm
import onnxruntime
import torchaudio.compliance.kaldi as kaldi
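# This script expects two Kaldi-style mapping files under --dir:
#   wav.scp : one "<utt_id> <wav_path>" pair per line
#   utt2spk : one "<utt_id> <spk_id>" pair per line
# Per-utterance embeddings are extracted with the ONNX model; the per-speaker
# embedding is the mean over that speaker's utterance embeddings.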
def main(args):
utt2wav, utt2spk = {}, {}
with open('{}/wav.scp'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2wav[l[0]] = l[1]
with open('{}/utt2spk'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2spk[l[0]] = l[1]
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ["CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
utt2embedding, spk2embedding = {}, {}
for utt in tqdm(utt2wav.keys()):
audio, sample_rate = torchaudio.load(utt2wav[utt])
if sample_rate != 16000:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
feat = kaldi.fbank(audio,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
utt2embedding[utt] = embedding
spk = utt2spk[utt]
if spk not in spk2embedding:
spk2embedding[spk] = []
spk2embedding[spk].append(embedding)
for k, v in spk2embedding.items():
spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dir',
type=str)
parser.add_argument('--onnx_path',
type=str)
args = parser.parse_args()
main(args)
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import torch
from tqdm import tqdm
import numpy as np
import torchaudio
import time
import os
from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer
from inspiremusic.utils.audio_utils import split_wav_into_chunks
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def main(args):
audio_min_length = 1.0
audio_max_length = 30.0
max_chunk_size = int(args.sample_rate * audio_max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
utt2wav = {}
with open('{}/wav.scp'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2wav[l[0]] = l[1]
wavtokenizer = WavTokenizer.from_pretrained_feat(args.config_path, args.ckpt_path).to(device)
bandwidth_id = torch.tensor([0]).to(device)
start_time = time.time()
utt2semantic_token = {}
for utt in tqdm(utt2wav.keys()):
audio, sample_rate = torchaudio.load(utt2wav[utt])
if sample_rate != args.sample_rate:
audio = torchaudio.functional.resample(audio, orig_freq=sample_rate, new_freq=args.sample_rate)
audio_length = audio.shape[1]
if audio_length > args.sample_rate * audio_min_length:
if audio_length > max_chunk_size:
wav_batch = split_wav_into_chunks(audio_length, audio, max_chunk_size)
for chunk in wav_batch:
chunk = torch.as_tensor(chunk, dtype=torch.float32).to(device)
_, semantic_token = wavtokenizer.encode_infer(chunk, bandwidth_id=bandwidth_id)
if semantic_token.is_cuda:
semantic_token = semantic_token.cpu()
semantic_token = semantic_token.squeeze(0).numpy().astype(np.int16)
if utt not in utt2semantic_token.keys():
utt2semantic_token[utt] = semantic_token
else:
utt2semantic_token[utt] = np.concatenate((utt2semantic_token[utt], semantic_token), axis=1)
else:
audio = audio.to(device=device, dtype=torch.float32)
_, semantic_token = wavtokenizer.encode_infer(audio, bandwidth_id=bandwidth_id)
if semantic_token.is_cuda:
semantic_token = semantic_token.cpu()
semantic_token = semantic_token.squeeze(0).numpy().astype(np.int16)
utt2semantic_token[utt] = semantic_token
else:
logging.warning(f'{utt} is shorter than {audio_min_length}s, skipping.')
torch.save(utt2semantic_token, '{}/utt2semantic_token.pt'.format(args.dir))
logging.info('elapsed time {:.2f}s'.format(time.time() - start_time))
def reconstruct(semantic_token_file, config_path, ckpt_path, outdir, sample_rate=24000):
if not os.path.isdir(outdir):
os.makedirs(outdir, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bandwidth_id = torch.tensor([0]).to(device)
wavtokenizer = WavTokenizer.from_pretrained_feat(config_path, ckpt_path).to(device)
utt2semantic_token = torch.load(semantic_token_file)
for utt in tqdm(utt2semantic_token.keys()):
token = utt2semantic_token[utt]
new_tensor = torch.tensor(token).to(device).unsqueeze(0)
features = wavtokenizer.codes_to_features(new_tensor)
wav = wavtokenizer.decode(features, bandwidth_id=bandwidth_id)
wav = wav.cpu().detach()
torchaudio.save(outdir + "/" + utt + ".wav", wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
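# `reconstruct` is not wired into the CLI below; a minimal manual invocation,
# assuming the default pretrained paths and a hypothetical --dir of "exp":
#
#     reconstruct("exp/utt2semantic_token.pt",
#                 "pretrained_models/InspireMusic-Base/wavtokenizer/config.yaml",
#                 "pretrained_models/InspireMusic-Base/wavtokenizer/model.pt",
#                 "./exp/wavs")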
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dir',
type=str)
parser.add_argument('--config_path',
type=str, default="pretrained_models/InspireMusic-Base/wavtokenizer/config.yaml")
parser.add_argument('--ckpt_path',
type=str, default="pretrained_models/InspireMusic-Base/wavtokenizer/model.pt")
parser.add_argument('--sample_rate',
default=24000,
type=int)
parser.add_argument('--outwavdir',
type=str, default="./exp/wavs")
args = parser.parse_args()
main(args)
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import torch
from tqdm import tqdm
import onnxruntime
import numpy as np
import torchaudio
import whisper
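# Speech tokens are derived from Whisper log-mel features fed to an ONNX
# tokenizer model; the 30 s cap below matches Whisper's fixed 30-second
# log-mel analysis window.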
def main(args):
utt2wav = {}
with open('{}/wav.scp'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2wav[l[0]] = l[1]
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ["CUDAExecutionProvider"]
ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
utt2speech_token = {}
for utt in tqdm(utt2wav.keys()):
audio, sample_rate = torchaudio.load(utt2wav[utt])
if sample_rate != 16000:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
if audio.shape[1] / 16000 > 30:
logging.warning('extracting speech tokens is not supported for audio longer than 30s')
speech_token = []
else:
feat = whisper.log_mel_spectrogram(audio, n_mels=128)
speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
utt2speech_token[utt] = speech_token
torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dir',
type=str)
parser.add_argument('--onnx_path',
type=str)
args = parser.parse_args()
main(args)
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc (authors: Chong Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import json
from tqdm import tqdm
import pandas as pd
import multiprocessing
import time
import torch
import numpy as np
import random
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def job(utt_list, token_list, parquet_file, utt2text, utt2time, utt2chorus, semantic_token_list):
start_time = time.time()
text_list = [utt2text[utt] for utt in utt_list]
time_start = [utt2time[utt][0] for utt in utt_list]
time_end = [utt2time[utt][1] for utt in utt_list]
chorus_list = [utt2chorus[utt] for utt in utt_list]
print(len(token_list))
# semantic_token_list is None when only acoustic tokens are being packed
if semantic_token_list is not None:
print(len(semantic_token_list))
try:
df = pd.DataFrame()
df['utt'] = utt_list
df['text'] = text_list
df['chorus'] = chorus_list
df['time_start'] = time_start
df['time_end'] = time_end
df["semantic_token"] = semantic_token_list
df["acoustic_token"] = token_list
logging.info(f'Starting to save parquet file: {parquet_file}')
df.to_parquet(parquet_file)
logging.info(f'Successfully saved parquet file: {parquet_file}')
except Exception as e:
logging.error(f'Error saving parquet file: {e}')
logging.info('Processing time {}s'.format(time.time() - start_time))
def text_only_job(utt_list, parquet_file, utt2text, utt2time, utt2chorus):
start_time = time.time()
text_list = [utt2text[utt] for utt in utt_list]
time_start = [utt2time[utt][0] for utt in utt_list]
time_end = [utt2time[utt][1] for utt in utt_list]
chorus_list = [utt2chorus[utt] for utt in utt_list]
try:
# save to parquet
df = pd.DataFrame()
df['utt'] = utt_list
df['text'] = text_list
df['chorus'] = chorus_list
df['time_start'] = time_start
df['time_end'] = time_end
logging.info(f'Starting to save parquet file: {parquet_file}')
df.to_parquet(parquet_file)
logging.info(f'Successfully saved parquet file: {parquet_file}')
except Exception as e:
logging.error(f'Error saving parquet file: {e}')
logging.info('Processing time {}s'.format(time.time() - start_time))
def parse_trans(line):
music_structure_labels = ["intro", "verse1", "chorus", "verse2", "verse", "outro"]
uid, l = line.strip().split("\t")
split = l.split("|><|")
# validate before indexing so malformed lines return None instead of raising
# an IndexError on the split[3] access below
if len(split) < 4:
print(line, split)
return None
time_start = float(split[0].replace("<|", ""))
time_end = float(split[-1].replace("|>", ""))
if time_start >= time_end:
print(line, split, time_start, time_end)
return None
if time_start < 0:
time_start = 0.0
chorus = split[1]
if split[2] == "lyrics":
text = "<|lyrics|> " + split[3]
elif split[2] == "music":
text = "<|music|>"
else:
text = split[2]
if chorus not in music_structure_labels:
chorus = random.choice(music_structure_labels)
if chorus in ["verse1", "verse2"]:
chorus = "verse"
return (uid, time_start, time_end, chorus, text)
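# Example with a hypothetical transcript line, following the "<|...|>" format
# implied by the parsing above (fields: time_start, structure label, text
# type, text, time_end):
#
#     parse_trans("uid_001\t<|0.0|><|chorus|><|lyrics|><|la la la|><|30.0|>")
#     # -> ("uid_001", 0.0, 30.0, "chorus", "<|lyrics|> la la la")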
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--num_utts_per_parquet',
type=int,
default=1000,
required=False,
help='num utts per parquet')
parser.add_argument('--num_processes',
type=int,
default=1,
required=False,
help='num processes for make parquets')
parser.add_argument('--src_dir',
type=str, required=True)
parser.add_argument('--des_dir',
type=str, required=True)
parser.add_argument('--semantic_token_dir',
type=str,
default=None, required=False)
parser.add_argument('--acoustic_token_dir',
type=str,
default=None, required=False)
args = parser.parse_args()
parquet_list = []
cnt = 0
utt2text = {}
utt2time = {}
utt2chorus = {}
uid_list = []
print(args)
if not os.path.exists(f'{args.src_dir}/text'):
raise FileNotFoundError(
f"Please check: {args.src_dir}/text file does not exist")
with open(f'{args.src_dir}/text', 'r') as f:
for l in f:
res = parse_trans(l)
if res is None:
continue
uid, time_start, time_end, chorus, text = res
uid_list.append(uid)
utt2time[uid] = (time_start, time_end)
utt2chorus[uid] = chorus
utt2text[uid] = text
utt2semantic_token = None
utt2acoustic_token = None
if args.semantic_token_dir is not None:
utt2semantic_token = {}
for fn in os.listdir(args.semantic_token_dir):
if fn.endswith("pt") and fn.startswith("utt2semantic_"):
print(f"Starting {fn}")
try:
utt2semantic_token.update(
torch.load('{}/{}'.format(args.semantic_token_dir, fn)))
except Exception:
print('{}/{} failed'.format(args.semantic_token_dir, fn))
print(len(utt2semantic_token))
# use a process pool to speed up parquet writing
pool = multiprocessing.Pool(processes=args.num_processes)
if args.acoustic_token_dir is not None:
for fn in os.listdir(args.acoustic_token_dir):
if fn.endswith("pt") and fn.startswith("utt2acoustic_"):
print(f"Starting {fn}")
utt2token = torch.load(
'{}/{}'.format(args.acoustic_token_dir, fn))
utts = [utt for utt in utt2token.keys() if utt in utt2text.keys()]
if utt2semantic_token:
utts = [utt for utt in utts if
utt in utt2semantic_token.keys()]
if len(utts) == 0:
print("0 lines remained.")
continue
if isinstance(utt2token[utts[0]], np.ndarray):
token_lists = [utt2token[utt][0].tolist() for utt in utts]
else:
token_lists = [
utt2token[utt].tolist() if utt2token[
utt].dim() == 2 else
utt2token[utt][0].tolist()
for utt in utts
]
print(len(token_lists))
semantic_token_lists = [
utt2semantic_token[utt].tolist() if not isinstance(
utt2semantic_token[utt], list) else
utt2semantic_token[utt] for utt in
utts] if utt2semantic_token else None
for i, j in enumerate(
range(0, len(utts), args.num_utts_per_parquet)):
parquet_file = os.path.join(args.des_dir,
'parquet_{:09d}.tar'.format(
cnt + i))
print(f"process {parquet_file}")
parquet_list.append(parquet_file)
token_list = token_lists[j: j + args.num_utts_per_parquet]
if semantic_token_lists:
semantic_token_list = semantic_token_lists[
j: j + args.num_utts_per_parquet]
else:
semantic_token_list = None
pool.apply_async(job, (
utts[j: j + args.num_utts_per_parquet], token_list,
parquet_file, utt2text, utt2time, utt2chorus,
semantic_token_list))
cnt += i + 1  # +1 so the next source file does not reuse the last parquet index
if args.semantic_token_dir is None and args.acoustic_token_dir is None:
for i, j in enumerate(
range(0, len(uid_list), args.num_utts_per_parquet)):
parquet_file = os.path.join(args.des_dir,
'parquet_{:09d}.tar'.format(cnt + i))
print(f"process {parquet_file}")
parquet_list.append(parquet_file)
pool.apply_async(text_only_job, (
uid_list[j: j + args.num_utts_per_parquet], parquet_file, utt2text,
utt2time, utt2chorus))
cnt += i + 1  # +1 so the next batch does not reuse the last parquet index
pool.close()
pool.join()
print("DONE")
with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1:
for name in parquet_list:
f1.write(name + '\n')