# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from pathlib import Path
from typing import Optional, List, Dict
import zipfile
import tempfile
from dataclasses import dataclass
from itertools import groupby
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from examples.speech_to_text.data_utils import load_tsv_to_dicts
from fairseq.data.audio.audio_utils import TTSSpectrogram, TTSMelScale
def trim_or_pad_to_target_length(
data_1d_or_2d: np.ndarray, target_length: int
) -> np.ndarray:
assert len(data_1d_or_2d.shape) in {1, 2}
delta = data_1d_or_2d.shape[0] - target_length
if delta >= 0: # trim if longer than target
data_1d_or_2d = data_1d_or_2d[: target_length]
else: # pad if shorter than target
if len(data_1d_or_2d.shape) == 1:
data_1d_or_2d = np.concatenate(
[data_1d_or_2d, np.zeros(-delta)], axis=0
)
else:
data_1d_or_2d = np.concatenate(
[data_1d_or_2d, np.zeros((-delta, data_1d_or_2d.shape[1]))],
axis=0
)
return data_1d_or_2d
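# Illustrative behavior: trimming keeps the leading target_length frames,
# padding appends zeros along the first (time) axis, e.g.
#   trim_or_pad_to_target_length(np.arange(5), 3)  -> array([0, 1, 2])
#   trim_or_pad_to_target_length(np.arange(3.), 5) -> array([0., 1., 2., 0., 0.])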
def extract_logmel_spectrogram(
waveform: torch.Tensor, sample_rate: int,
output_path: Optional[Path] = None, win_length: int = 1024,
hop_length: int = 256, n_fft: int = 1024,
win_fn: callable = torch.hann_window, n_mels: int = 80,
f_min: float = 0., f_max: float = 8000, eps: float = 1e-5,
overwrite: bool = False, target_length: Optional[int] = None
):
if output_path is not None and output_path.is_file() and not overwrite:
return
spectrogram_transform = TTSSpectrogram(
n_fft=n_fft, win_length=win_length, hop_length=hop_length,
window_fn=win_fn
)
mel_scale_transform = TTSMelScale(
n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
n_stft=n_fft // 2 + 1
)
spectrogram = spectrogram_transform(waveform)
mel_spec = mel_scale_transform(spectrogram)
logmel_spec = torch.clamp(mel_spec, min=eps).log()
assert len(logmel_spec.shape) == 3 and logmel_spec.shape[0] == 1
logmel_spec = logmel_spec.squeeze().t() # D x T -> T x D
if target_length is not None:
logmel_spec = trim_or_pad_to_target_length(logmel_spec, target_length)
if output_path is not None:
np.save(output_path.as_posix(), logmel_spec)
else:
return logmel_spec
def extract_pitch(
waveform: torch.Tensor, sample_rate: int,
output_path: Optional[Path] = None, hop_length: int = 256,
log_scale: bool = True, phoneme_durations: Optional[List[int]] = None
):
if output_path is not None and output_path.is_file():
return
try:
import pyworld
except ImportError:
raise ImportError("Please install PyWORLD: pip install pyworld")
_waveform = waveform.squeeze(0).double().numpy()
pitch, t = pyworld.dio(
_waveform, sample_rate, frame_period=hop_length / sample_rate * 1000
)
pitch = pyworld.stonemask(_waveform, pitch, t, sample_rate)
if phoneme_durations is not None:
pitch = trim_or_pad_to_target_length(pitch, sum(phoneme_durations))
try:
from scipy.interpolate import interp1d
except ImportError:
raise ImportError("Please install SciPy: pip install scipy")
nonzero_ids = np.where(pitch != 0)[0]
interp_fn = interp1d(
nonzero_ids,
pitch[nonzero_ids],
fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
bounds_error=False,
)
pitch = interp_fn(np.arange(0, len(pitch)))
d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
pitch = np.array(
[
np.mean(pitch[d_cumsum[i-1]: d_cumsum[i]])
for i in range(1, len(d_cumsum))
]
)
assert len(pitch) == len(phoneme_durations)
if log_scale:
pitch = np.log(pitch + 1)
if output_path is not None:
np.save(output_path.as_posix(), pitch)
else:
return pitch
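# A minimal usage sketch (illustrative): extract a log-scale pitch contour from
# a mono waveform (loaded e.g. with torchaudio) without phoneme-level averaging.
#   waveform, sample_rate = torchaudio.load("sample.wav")
#   pitch = extract_pitch(waveform, sample_rate, output_path=None)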
def extract_energy(
waveform: torch.Tensor, output_path: Optional[Path] = None,
hop_length: int = 256, n_fft: int = 1024, log_scale: bool = True,
phoneme_durations: Optional[List[int]] = None
):
if output_path is not None and output_path.is_file():
return
assert len(waveform.shape) == 2 and waveform.shape[0] == 1
waveform = waveform.view(1, 1, waveform.shape[1])
waveform = F.pad(
waveform.unsqueeze(1), [n_fft // 2, n_fft // 2, 0, 0],
mode="reflect"
)
waveform = waveform.squeeze(1)
fourier_basis = np.fft.fft(np.eye(n_fft))
cutoff = int((n_fft / 2 + 1))
fourier_basis = np.vstack(
[np.real(fourier_basis[:cutoff, :]),
np.imag(fourier_basis[:cutoff, :])]
)
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
forward_transform = F.conv1d(
waveform, forward_basis, stride=hop_length, padding=0
)
real_part = forward_transform[:, :cutoff, :]
imag_part = forward_transform[:, cutoff:, :]
magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
energy = torch.norm(magnitude, dim=1).squeeze(0).numpy()
if phoneme_durations is not None:
energy = trim_or_pad_to_target_length(energy, sum(phoneme_durations))
d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
energy = np.array(
[
np.mean(energy[d_cumsum[i - 1]: d_cumsum[i]])
for i in range(1, len(d_cumsum))
]
)
assert len(energy) == len(phoneme_durations)
if log_scale:
energy = np.log(energy + 1)
if output_path is not None:
np.save(output_path.as_posix(), energy)
else:
return energy
def get_global_cmvn(feature_root: Path, output_path: Optional[Path] = None):
mean_x, mean_x2, n_frames = None, None, 0
feature_paths = feature_root.glob("*.npy")
for p in tqdm(feature_paths):
with open(p, 'rb') as f:
frames = np.load(f).squeeze()
n_frames += frames.shape[0]
cur_mean_x = frames.sum(axis=0)
if mean_x is None:
mean_x = cur_mean_x
else:
mean_x += cur_mean_x
cur_mean_x2 = (frames ** 2).sum(axis=0)
if mean_x2 is None:
mean_x2 = cur_mean_x2
else:
mean_x2 += cur_mean_x2
mean_x /= n_frames
mean_x2 /= n_frames
var_x = mean_x2 - mean_x ** 2
std_x = np.sqrt(np.maximum(var_x, 1e-10))
if output_path is not None:
with open(output_path, 'wb') as f:
np.savez(f, mean=mean_x, std=std_x)
else:
return {"mean": mean_x, "std": std_x}
def ipa_phonemize(text, lang="en-us", use_g2p=False):
if use_g2p:
assert lang == "en-us", "g2pE phonemizer only works for en-us"
try:
from g2p_en import G2p
g2p = G2p()
return " ".join("|" if p == " " else p for p in g2p(text))
except ImportError:
raise ImportError(
"Please install phonemizer: pip install g2p_en"
)
else:
try:
from phonemizer import phonemize
from phonemizer.separator import Separator
return phonemize(
text, backend='espeak', language=lang,
separator=Separator(word="| ", phone=" ")
)
except ImportError:
raise ImportError(
"Please install phonemizer: pip install phonemizer"
)
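# Illustrative output format: a space-separated phoneme sequence with "|"
# marking word boundaries. The exact phonemes depend on the backend/model;
# the example below is an assumed sketch only:
#   ipa_phonemize("a cat", lang="en-us", use_g2p=True)  ->  "AH0 | K AE1 T"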
@dataclass
class ForceAlignmentInfo(object):
tokens: List[str]
frame_durations: List[int]
start_sec: Optional[float]
end_sec: Optional[float]
def get_mfa_alignment_by_sample_id(
textgrid_zip_path: str, sample_id: str, sample_rate: int,
hop_length: int, silence_phones: List[str] = ("sil", "sp", "spn")
) -> ForceAlignmentInfo:
try:
import tgt
except ImportError:
raise ImportError("Please install TextGridTools: pip install tgt")
filename = f"{sample_id}.TextGrid"
out_root = Path(tempfile.gettempdir())
tgt_path = out_root / filename
with zipfile.ZipFile(textgrid_zip_path) as f_zip:
f_zip.extract(filename, path=out_root)
textgrid = tgt.io.read_textgrid(tgt_path.as_posix())
os.remove(tgt_path)
phones, frame_durations = [], []
start_sec, end_sec, end_idx = 0, 0, 0
for t in textgrid.get_tier_by_name("phones")._objects:
s, e, p = t.start_time, t.end_time, t.text
# Trim leading silences
if len(phones) == 0:
if p in silence_phones:
continue
else:
start_sec = s
phones.append(p)
if p not in silence_phones:
end_sec = e
end_idx = len(phones)
r = sample_rate / hop_length
frame_durations.append(int(np.round(e * r) - np.round(s * r)))
# Trim trailing silences
phones = phones[:end_idx]
frame_durations = frame_durations[:end_idx]
return ForceAlignmentInfo(
tokens=phones, frame_durations=frame_durations, start_sec=start_sec,
end_sec=end_sec
)
def get_mfa_alignment(
textgrid_zip_path: str, sample_ids: List[str], sample_rate: int,
hop_length: int
) -> Dict[str, ForceAlignmentInfo]:
return {
i: get_mfa_alignment_by_sample_id(
textgrid_zip_path, i, sample_rate, hop_length
) for i in tqdm(sample_ids)
}
def get_unit_alignment(
id_to_unit_tsv_path: str, sample_ids: List[str]
) -> Dict[str, ForceAlignmentInfo]:
id_to_units = {
e["id"]: e["units"] for e in load_tsv_to_dicts(id_to_unit_tsv_path)
}
id_to_units = {i: id_to_units[i].split() for i in sample_ids}
id_to_units_collapsed = {
i: [uu for uu, _ in groupby(u)] for i, u in id_to_units.items()
}
id_to_durations = {
i: [len(list(g)) for _, g in groupby(u)] for i, u in id_to_units.items()
}
return {
i: ForceAlignmentInfo(
tokens=id_to_units_collapsed[i], frame_durations=id_to_durations[i],
start_sec=None, end_sec=None
)
for i in sample_ids
}
[[Back]](..)
# Common Voice
[Common Voice](https://commonvoice.mozilla.org/en/datasets) is a public domain speech corpus with 11.2K hours of read
speech in 76 languages (as of the latest version, 7.0). We provide examples for building
[Transformer](https://arxiv.org/abs/1809.08895) models on this dataset.
## Data preparation
[Download](https://commonvoice.mozilla.org/en/datasets) and unpack Common Voice v4 to a path `${DATA_ROOT}/${LANG_ID}`.
Create splits and generate audio manifests with
```bash
python -m examples.speech_synthesis.preprocessing.get_common_voice_audio_manifest \
--data-root ${DATA_ROOT} \
--lang ${LANG_ID} \
--output-manifest-root ${AUDIO_MANIFEST_ROOT} --convert-to-wav
```
Then, extract log-Mel spectrograms, generate feature manifest and create data configuration YAML with
```bash
python -m examples.speech_synthesis.preprocessing.get_feature_manifest \
--audio-manifest-root ${AUDIO_MANIFEST_ROOT} \
--output-root ${FEATURE_MANIFEST_ROOT} \
--ipa-vocab --lang ${LANG_ID}
```
where we use phoneme inputs (`--ipa-vocab`) as an example.
To denoise audio and trim leading/trailing silence using signal processing based VAD, run
```bash
for SPLIT in dev test train; do
python -m examples.speech_synthesis.preprocessing.denoise_and_vad_audio \
--audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
--output-dir ${PROCESSED_DATA_ROOT} \
--denoise --vad --vad-agg-level 2
done
```
## Training
(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#transformer).)
## Inference
(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#inference).)
## Automatic Evaluation
(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#automatic-evaluation).)
## Results
| Language | Speakers | --arch | Params | Test MCD | Model |
|---|---|---|---|---|---|
| English | 200 | tts_transformer | 54M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/cv4_en200_transformer_phn.tar) |
[[Back]](..)
[[Back]](..)
# LJSpeech
[LJSpeech](https://keithito.com/LJ-Speech-Dataset) is a public domain TTS
corpus with around 24 hours of English speech sampled at 22.05kHz. We provide examples for building
[Transformer](https://arxiv.org/abs/1809.08895) and [FastSpeech 2](https://arxiv.org/abs/2006.04558)
models on this dataset.
## Data preparation
Download data, create splits and generate audio manifests with
```bash
python -m examples.speech_synthesis.preprocessing.get_ljspeech_audio_manifest \
--output-data-root ${AUDIO_DATA_ROOT} \
--output-manifest-root ${AUDIO_MANIFEST_ROOT}
```
Then, extract log-Mel spectrograms, generate feature manifest and create data configuration YAML with
```bash
python -m examples.speech_synthesis.preprocessing.get_feature_manifest \
--audio-manifest-root ${AUDIO_MANIFEST_ROOT} \
--output-root ${FEATURE_MANIFEST_ROOT} \
--ipa-vocab --use-g2p
```
where we use phoneme inputs (`--ipa-vocab --use-g2p`) as an example.
FastSpeech 2 additionally requires frame durations, pitch and energy as auxiliary training targets.
Add `--add-fastspeech-targets` to include these fields in the feature manifests. We get frame durations either from
phoneme-level force-alignment or frame-level pseudo-text unit sequences. They should be pre-computed and specified via:
- `--textgrid-zip ${TEXT_GRID_ZIP_PATH}` for a ZIP file, inside which there is one
[TextGrid](https://www.fon.hum.uva.nl/praat/manual/TextGrid.html) file per sample to provide force-alignment info.
- `--id-to-units-tsv ${ID_TO_UNIT_TSV}` for a TSV file, where there are 2 columns for sample ID and
space-delimited pseudo-text unit sequence, respectively.
For your convenience, we provide pre-computed
[force-alignment](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_mfa.zip) from
[Montreal Forced Aligner](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) and
[pseudo-text units](s3://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_hubert.tsv) from
[HuBERT](https://github.com/pytorch/fairseq/tree/main/examples/hubert). You can also generate them yourself using
different tools or models.
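For pseudo-text units, frame durations follow from collapsing repeated units: consecutive identical units become one token whose duration is its run length. A small illustrative sketch (the unit values below are made up):
```python
from itertools import groupby

# Hypothetical frame-level unit sequence for one sample (one unit per frame)
units = "17 17 17 4 4 92 92 92 92 11".split()

tokens = [u for u, _ in groupby(units)]                # ['17', '4', '92', '11']
durations = [len(list(g)) for _, g in groupby(units)]  # [3, 2, 4, 1]
```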
## Training
#### Transformer
```bash
fairseq-train ${FEATURE_MANIFEST_ROOT} --save-dir ${SAVE_DIR} \
--config-yaml config.yaml --train-subset train --valid-subset dev \
--num-workers 4 --max-tokens 30000 --max-update 200000 \
--task text_to_speech --criterion tacotron2 --arch tts_transformer \
--clip-norm 5.0 --n-frames-per-step 4 --bce-pos-weight 5.0 \
--dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.1 \
--encoder-normalize-before --decoder-normalize-before \
--optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--seed 1 --update-freq 8 --eval-inference --best-checkpoint-metric mcd_loss
```
where `SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to
adjust it accordingly when using more than 1 GPU.
#### FastSpeech2
```bash
fairseq-train ${FEATURE_MANIFEST_ROOT} --save-dir ${SAVE_DIR} \
--config-yaml config.yaml --train-subset train --valid-subset dev \
--num-workers 4 --max-sentences 6 --max-update 200000 \
--task text_to_speech --criterion fastspeech2 --arch fastspeech2 \
--clip-norm 5.0 --n-frames-per-step 1 \
--dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.1 \
--encoder-normalize-before --decoder-normalize-before \
--optimizer adam --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--seed 1 --update-freq 8 --eval-inference --best-checkpoint-metric mcd_loss
```
## Inference
Average the last 5 checkpoints, then generate spectrograms and waveforms for the test split using the default Griffin-Lim vocoder:
```bash
SPLIT=test
CHECKPOINT_NAME=avg_last_5
CHECKPOINT_PATH=${SAVE_DIR}/checkpoint_${CHECKPOINT_NAME}.pt
python scripts/average_checkpoints.py --inputs ${SAVE_DIR} \
--num-epoch-checkpoints 5 \
--output ${CHECKPOINT_PATH}
python -m examples.speech_synthesis.generate_waveform ${FEATURE_MANIFEST_ROOT} \
--config-yaml config.yaml --gen-subset ${SPLIT} --task text_to_speech \
--path ${CHECKPOINT_PATH} --max-tokens 50000 --spec-bwd-max-iter 32 \
--dump-waveforms
```
which dumps files (waveform, feature, attention plot, etc.) to `${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT}`. To
re-synthesize target waveforms for automatic evaluation, add `--dump-target`.
## Automatic Evaluation
To start with, generate the manifest for synthetic speech, which will be used as input by the evaluation scripts.
```bash
python -m examples.speech_synthesis.evaluation.get_eval_manifest \
--generation-root ${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT} \
--audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
--output-path ${EVAL_OUTPUT_ROOT}/eval.tsv \
--vocoder griffin_lim --sample-rate 22050 --audio-format flac \
--use-resynthesized-target
```
Speech recognition (ASR) models usually operate at lower sample rates (e.g. 16kHz). For the WER/CER metric,
you may need to resample the audio accordingly: add `--output-sample-rate 16000` for `generate_waveform.py` and
use `--sample-rate 16000` for `get_eval_manifest.py`.
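For reference, `generate_waveform.py` performs this resampling internally with torchaudio's sox effects; an equivalent standalone sketch (file paths are illustrative) is:
```python
import torchaudio

wav, sr = torchaudio.load("syn_22050hz.wav")  # illustrative input path
wav_16k, sr_16k = torchaudio.sox_effects.apply_effects_tensor(
    wav, sr, [["rate", "16000"]]
)
torchaudio.save("syn_16000hz.wav", wav_16k, sr_16k)
```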
#### WER/CER metric
We use a wav2vec 2.0 ASR model as an example. [Download](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec)
the model checkpoint and dictionary, then compute WER/CER with
```bash
python -m examples.speech_synthesis.evaluation.eval_asr \
--audio-header syn --text-header text --err-unit char --split ${SPLIT} \
--w2v-ckpt ${WAV2VEC2_CHECKPOINT_PATH} --w2v-dict-dir ${WAV2VEC2_DICT_DIR} \
--raw-manifest ${EVAL_OUTPUT_ROOT}/eval_16khz.tsv --asr-dir ${EVAL_OUTPUT_ROOT}/asr
```
#### MCD/MSD metric
```bash
python -m examples.speech_synthesis.evaluation.eval_sp \
${EVAL_OUTPUT_ROOT}/eval.tsv --mcd --msd
```
#### F0 metrics
```bash
python -m examples.speech_synthesis.evaluation.eval_f0 \
${EVAL_OUTPUT_ROOT}/eval.tsv --gpe --vde --ffe
```
## Results
| --arch | Params | Test MCD | Model |
|---|---|---|---|
| tts_transformer | 54M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_transformer_phn.tar) |
| fastspeech2 | 41M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_fastspeech2_phn.tar) |
[[Back]](..)
[[Back]](..)
# VCTK
[VCTK](https://datashare.ed.ac.uk/handle/10283/3443) is an open English speech corpus. We provide examples
for building [Transformer](https://arxiv.org/abs/1809.08895) models on this dataset.
## Data preparation
Download data, create splits and generate audio manifests with
```bash
python -m examples.speech_synthesis.preprocessing.get_vctk_audio_manifest \
--output-data-root ${AUDIO_DATA_ROOT} \
--output-manifest-root ${AUDIO_MANIFEST_ROOT}
```
Then, extract log-Mel spectrograms, generate feature manifest and create data configuration YAML with
```bash
python -m examples.speech_synthesis.preprocessing.get_feature_manifest \
--audio-manifest-root ${AUDIO_MANIFEST_ROOT} \
--output-root ${FEATURE_MANIFEST_ROOT} \
--ipa-vocab --use-g2p
```
where we use phoneme inputs (`--ipa-vocab --use-g2p`) as an example.
To denoise audio and trim leading/trailing silence using signal processing based VAD, run
```bash
for SPLIT in dev test train; do
python -m examples.speech_synthesis.preprocessing.denoise_and_vad_audio \
--audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
--output-dir ${PROCESSED_DATA_ROOT} \
--denoise --vad --vad-agg-level 3
done
```
## Training
(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#transformer).)
## Inference
(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#inference).)
## Automatic Evaluation
(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#automatic-evaluation).)
## Results
| --arch | Params | Test MCD | Model |
|---|---|---|---|
| tts_transformer | 54M | 3.4 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/vctk_transformer_phn.tar) |
[[Back]](..)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import editdistance
import re
import shutil
import soundfile as sf
import subprocess
from pathlib import Path
from examples.speech_to_text.data_utils import load_tsv_to_dicts
def preprocess_text(text):
text = "|".join(re.sub(r"[^A-Z' ]", " ", text.upper()).split())
text = " ".join(text)
return text
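# Illustrative behavior: preprocess_text("Hello, world!") returns
# "H E L L O | W O R L D", i.e. upper-case letters separated by spaces with
# "|" as the word boundary, matching wav2vec 2.0 letter ("ltr") labels.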
def prepare_w2v_data(
dict_dir, sample_rate, label, audio_paths, texts, split, data_dir
):
data_dir.mkdir(parents=True, exist_ok=True)
shutil.copyfile(
dict_dir / f"dict.{label}.txt",
data_dir / f"dict.{label}.txt"
)
with open(data_dir / f"{split}.tsv", "w") as f:
f.write("/\n")
for audio_path in audio_paths:
wav, sr = sf.read(audio_path)
assert sr == sample_rate, f"{sr} != sample_rate"
nsample = len(wav)
f.write(f"{audio_path}\t{nsample}\n")
with open(data_dir / f"{split}.{label}", "w") as f:
for text in texts:
text = preprocess_text(text)
f.write(f"{text}\n")
def run_asr(asr_dir, split, w2v_ckpt, w2v_label, res_dir):
"""
results will be saved at
{res_dir}/{ref,hypo}.word-{w2v_ckpt.filename}-{split}.txt
"""
cmd = ["python", "-m", "examples.speech_recognition.infer"]
cmd += [str(asr_dir.resolve())]
cmd += ["--task", "audio_finetuning", "--nbest", "1", "--quiet"]
cmd += ["--w2l-decoder", "viterbi", "--criterion", "ctc"]
cmd += ["--post-process", "letter", "--max-tokens", "4000000"]
cmd += ["--path", str(w2v_ckpt.resolve()), "--labels", w2v_label]
cmd += ["--gen-subset", split, "--results-path", str(res_dir.resolve())]
print(f"running cmd:\n{' '.join(cmd)}")
subprocess.run(cmd, check=True)
def compute_error_rate(hyp_wrd_path, ref_wrd_path, unit="word"):
"""each line is "<text> (None-<index>)" """
tokenize_line = {
"word": lambda x: re.sub(r" \(.*\)$", "", x.rstrip()).split(),
"char": lambda x: list(re.sub(r" \(.*\)$", "", x.rstrip()))
}.get(unit)
if tokenize_line is None:
raise ValueError(f"{unit} not supported")
inds = [int(re.sub(r"\D*(\d*)\D*", r"\1", line))
for line in open(hyp_wrd_path)]
hyps = [tokenize_line(line) for line in open(hyp_wrd_path)]
refs = [tokenize_line(line) for line in open(ref_wrd_path)]
assert(len(hyps) == len(refs))
err_rates = [
editdistance.eval(hyp, ref) / len(ref) for hyp, ref in zip(hyps, refs)
]
ind_to_err_rates = {i: e for i, e in zip(inds, err_rates)}
return ind_to_err_rates
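# Illustrative behavior: with a hypothesis line "A B C (None-0)" and a
# reference line "A B D (None-0)", unit="word" gives an edit distance of 1
# over 3 reference words, i.e. {0: 0.3333...}.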
def main(args):
samples = load_tsv_to_dicts(args.raw_manifest)
ids = [
sample[args.id_header] if args.id_header else "" for sample in samples
]
audio_paths = [sample[args.audio_header] for sample in samples]
texts = [sample[args.text_header] for sample in samples]
prepare_w2v_data(
args.w2v_dict_dir,
args.w2v_sample_rate,
args.w2v_label,
audio_paths,
texts,
args.split,
args.asr_dir
)
run_asr(args.asr_dir, args.split, args.w2v_ckpt, args.w2v_label, args.asr_dir)
ind_to_err_rates = compute_error_rate(
args.asr_dir / f"hypo.word-{args.w2v_ckpt.name}-{args.split}.txt",
args.asr_dir / f"ref.word-{args.w2v_ckpt.name}-{args.split}.txt",
args.err_unit,
)
uer_path = args.asr_dir / f"uer_{args.err_unit}.{args.split}.tsv"
with open(uer_path, "w") as f:
f.write("id\taudio\tuer\n")
for ind, (id_, audio_path) in enumerate(zip(ids, audio_paths)):
f.write(f"{id_}\t{audio_path}\t{ind_to_err_rates[ind]:.4f}\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--raw-manifest", required=True, type=Path)
parser.add_argument("--asr-dir", required=True, type=Path)
parser.add_argument("--id-header", default="id", type=str)
parser.add_argument("--audio-header", default="audio", type=str)
parser.add_argument("--text-header", default="src_text", type=str)
parser.add_argument("--split", default="raw", type=str)
parser.add_argument("--w2v-ckpt", required=True, type=Path)
parser.add_argument("--w2v-dict-dir", required=True, type=Path)
parser.add_argument("--w2v-sample-rate", default=16000, type=int)
parser.add_argument("--w2v-label", default="ltr", type=str)
parser.add_argument("--err-unit", default="word", type=str)
args = parser.parse_args()
main(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Signal processing-based evaluation using waveforms
"""
import numpy as np
import os.path as op
import torchaudio
import tqdm
from tabulate import tabulate
from examples.speech_synthesis.utils import (
gross_pitch_error, voicing_decision_error, f0_frame_error
)
from examples.speech_synthesis.evaluation.eval_sp import load_eval_spec
def difference_function(x, n, tau_max):
"""
Compute the difference function of data x. This solution is implemented directly
with NumPy's FFT.
:param x: audio data
:param n: length of data
:param tau_max: integration window size
:return: difference function
:rtype: list
"""
x = np.array(x, np.float64)
w = x.size
tau_max = min(tau_max, w)
x_cumsum = np.concatenate((np.array([0.]), (x * x).cumsum()))
size = w + tau_max
p2 = (size // 32).bit_length()
nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32)
size_pad = min(x * 2 ** p2 for x in nice_numbers if x * 2 ** p2 >= size)
fc = np.fft.rfft(x, size_pad)
conv = np.fft.irfft(fc * fc.conjugate())[:tau_max]
return x_cumsum[w:w - tau_max:-1] + x_cumsum[w] - x_cumsum[:tau_max] - \
2 * conv
def cumulative_mean_normalized_difference_function(df, n):
"""
Compute cumulative mean normalized difference function (CMND).
:param df: Difference function
:param n: length of data
:return: cumulative mean normalized difference function
:rtype: list
"""
# scipy method
cmn_df = df[1:] * range(1, n) / np.cumsum(df[1:]).astype(float)
return np.insert(cmn_df, 0, 1)
def get_pitch(cmdf, tau_min, tau_max, harmo_th=0.1):
"""
Return fundamental period of a frame based on CMND function.
:param cmdf: Cumulative Mean Normalized Difference function
:param tau_min: minimum period for speech
:param tau_max: maximum period for speech
:param harmo_th: harmonicity threshold to determine if it is necessary to
compute pitch frequency
:return: fundamental period if there are values under the threshold, 0 otherwise
:rtype: float
"""
tau = tau_min
while tau < tau_max:
if cmdf[tau] < harmo_th:
while tau + 1 < tau_max and cmdf[tau + 1] < cmdf[tau]:
tau += 1
return tau
tau += 1
return 0 # if unvoiced
def compute_yin(sig, sr, w_len=512, w_step=256, f0_min=100, f0_max=500,
harmo_thresh=0.1):
"""
Compute the YIN algorithm. Return fundamental frequency and harmonic rate.
https://github.com/NVIDIA/mellotron adaption of
https://github.com/patriceguyot/Yin
:param sig: Audio signal (list of float)
:param sr: sampling rate (int)
:param w_len: size of the analysis window (samples)
:param w_step: size of the lag between two consecutive windows (samples)
:param f0_min: Minimum fundamental frequency that can be detected (hertz)
:param f0_max: Maximum fundamental frequency that can be detected (hertz)
:param harmo_thresh: Threshold of detection. The algorithm returns the
first minimum of the CMND function below this threshold.
:returns:
* pitches: list of fundamental frequencies,
* harmonic_rates: list of harmonic rate values for each fundamental
frequency value (= confidence value)
* argmins: minima of the Cumulative Mean Normalized Difference Function
* times: list of time of each estimation
:rtype: tuple
"""
tau_min = int(sr / f0_max)
tau_max = int(sr / f0_min)
# time values for each analysis window
time_scale = range(0, len(sig) - w_len, w_step)
times = [t/float(sr) for t in time_scale]
frames = [sig[t:t + w_len] for t in time_scale]
pitches = [0.0] * len(time_scale)
harmonic_rates = [0.0] * len(time_scale)
argmins = [0.0] * len(time_scale)
for i, frame in enumerate(frames):
# Compute YIN
df = difference_function(frame, w_len, tau_max)
cm_df = cumulative_mean_normalized_difference_function(df, tau_max)
p = get_pitch(cm_df, tau_min, tau_max, harmo_thresh)
# Get results
if np.argmin(cm_df) > tau_min:
argmins[i] = float(sr / np.argmin(cm_df))
if p != 0: # A pitch was found
pitches[i] = float(sr / p)
harmonic_rates[i] = cm_df[p]
else: # No pitch, but we compute a value of the harmonic rate
harmonic_rates[i] = min(cm_df)
return pitches, harmonic_rates, argmins, times
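# A minimal usage sketch (illustrative helper, not used elsewhere in this
# module): run YIN on a mono audio file, mirroring how extract_f0() below
# loads waveforms.
def compute_yin_from_file(path, **kwargs):
    y, sr = torchaudio.load(path)  # assume single channel
    return compute_yin(y[0], sr, **kwargs)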
def extract_f0(samples):
f0_samples = []
for sample in tqdm.tqdm(samples):
if not op.isfile(sample["ref"]) or not op.isfile(sample["syn"]):
f0_samples.append(None)
continue
# assume single channel
yref, sr = torchaudio.load(sample["ref"])
ysyn, _sr = torchaudio.load(sample["syn"])
yref, ysyn = yref[0], ysyn[0]
assert sr == _sr, f"{sr} != {_sr}"
yref_f0 = compute_yin(yref, sr)
ysyn_f0 = compute_yin(ysyn, sr)
f0_samples += [
{
"ref": yref_f0,
"syn": ysyn_f0
}
]
return f0_samples
def eval_f0_error(samples, distortion_fn):
results = []
for sample in tqdm.tqdm(samples):
if sample is None:
results.append(None)
continue
# assume single channel
yref_f, _, _, yref_t = sample["ref"]
ysyn_f, _, _, ysyn_t = sample["syn"]
yref_f = np.array(yref_f)
yref_t = np.array(yref_t)
ysyn_f = np.array(ysyn_f)
ysyn_t = np.array(ysyn_t)
distortion = distortion_fn(yref_t, yref_f, ysyn_t, ysyn_f)
results.append((distortion.item(),
len(yref_f),
len(ysyn_f)
))
return results
def eval_gross_pitch_error(samples):
return eval_f0_error(samples, gross_pitch_error)
def eval_voicing_decision_error(samples):
return eval_f0_error(samples, voicing_decision_error)
def eval_f0_frame_error(samples):
return eval_f0_error(samples, f0_frame_error)
def print_results(results, show_bin):
results = np.array(list(filter(lambda x: x is not None, results)))
np.set_printoptions(precision=3)
def _print_result(results):
res = {
"nutt": len(results),
"error": results[:, 0].mean(),
"std": results[:, 0].std(),
"dur_ref": int(results[:, 1].sum()),
"dur_syn": int(results[:, 2].sum()),
}
print(tabulate([res.values()], res.keys(), floatfmt=".4f"))
print(">>>> ALL")
_print_result(results)
if show_bin:
edges = [0, 200, 400, 600, 800, 1000, 2000, 4000]
for i in range(1, len(edges)):
mask = np.logical_and(results[:, 1] >= edges[i-1],
results[:, 1] < edges[i])
if not mask.any():
continue
bin_results = results[mask]
print(f">>>> ({edges[i-1]}, {edges[i]})")
_print_result(bin_results)
def main(eval_f0, gpe, vde, ffe, show_bin):
samples = load_eval_spec(eval_f0)
if gpe or vde or ffe:
f0_samples = extract_f0(samples)
if gpe:
print("===== Evaluate Gross Pitch Error =====")
results = eval_gross_pitch_error(f0_samples)
print_results(results, show_bin)
if vde:
print("===== Evaluate Voicing Decision Error =====")
results = eval_voicing_decision_error(f0_samples)
print_results(results, show_bin)
if ffe:
print("===== Evaluate F0 Frame Error =====")
results = eval_f0_frame_error(f0_samples)
print_results(results, show_bin)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("eval_f0")
parser.add_argument("--gpe", action="store_true")
parser.add_argument("--vde", action="store_true")
parser.add_argument("--ffe", action="store_true")
parser.add_argument("--show-bin", action="store_true")
args = parser.parse_args()
main(args.eval_f0, args.gpe, args.vde, args.ffe, args.show_bin)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Signal processing-based evaluation using waveforms
"""
import csv
import numpy as np
import os.path as op
import torch
import tqdm
from tabulate import tabulate
import torchaudio
from examples.speech_synthesis.utils import batch_mel_spectral_distortion
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
def load_eval_spec(path):
with open(path) as f:
reader = csv.DictReader(f, delimiter='\t')
samples = list(reader)
return samples
def eval_distortion(samples, distortion_fn, device="cuda"):
nmiss = 0
results = []
for sample in tqdm.tqdm(samples):
if not op.isfile(sample["ref"]) or not op.isfile(sample["syn"]):
nmiss += 1
results.append(None)
continue
# assume single channel
yref, sr = torchaudio.load(sample["ref"])
ysyn, _sr = torchaudio.load(sample["syn"])
yref, ysyn = yref[0].to(device), ysyn[0].to(device)
assert sr == _sr, f"{sr} != {_sr}"
distortion, extra = distortion_fn([yref], [ysyn], sr, None)[0]
_, _, _, _, _, pathmap = extra
nins = torch.sum(pathmap.sum(dim=1) - 1) # extra frames in syn
ndel = torch.sum(pathmap.sum(dim=0) - 1) # missing frames from syn
results.append(
(distortion.item(), # path distortion
pathmap.size(0), # yref num frames
pathmap.size(1), # ysyn num frames
pathmap.sum().item(), # path length
nins.item(), # insertion
ndel.item(), # deletion
)
)
return results
def eval_mel_cepstral_distortion(samples, device="cuda"):
return eval_distortion(samples, batch_mel_cepstral_distortion, device)
def eval_mel_spectral_distortion(samples, device="cuda"):
return eval_distortion(samples, batch_mel_spectral_distortion, device)
def print_results(results, show_bin):
results = np.array(list(filter(lambda x: x is not None, results)))
np.set_printoptions(precision=3)
def _print_result(results):
dist, dur_ref, dur_syn, dur_ali, nins, ndel = results.sum(axis=0)
res = {
"nutt": len(results),
"dist": dist,
"dur_ref": int(dur_ref),
"dur_syn": int(dur_syn),
"dur_ali": int(dur_ali),
"dist_per_ref_frm": dist/dur_ref,
"dist_per_syn_frm": dist/dur_syn,
"dist_per_ali_frm": dist/dur_ali,
"ins": nins/dur_ref,
"del": ndel/dur_ref,
}
print(tabulate(
[res.values()],
res.keys(),
floatfmt=".4f"
))
print(">>>> ALL")
_print_result(results)
if show_bin:
edges = [0, 200, 400, 600, 800, 1000, 2000, 4000]
for i in range(1, len(edges)):
mask = np.logical_and(results[:, 1] >= edges[i-1],
results[:, 1] < edges[i])
if not mask.any():
continue
bin_results = results[mask]
print(f">>>> ({edges[i-1]}, {edges[i]})")
_print_result(bin_results)
def main(eval_spec, mcd, msd, show_bin):
samples = load_eval_spec(eval_spec)
device = "cpu"
if mcd:
print("===== Evaluate Mean Cepstral Distortion =====")
results = eval_mel_cepstral_distortion(samples, device)
print_results(results, show_bin)
if msd:
print("===== Evaluate Mean Spectral Distortion =====")
results = eval_mel_spectral_distortion(samples, device)
print_results(results, show_bin)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("eval_spec")
parser.add_argument("--mcd", action="store_true")
parser.add_argument("--msd", action="store_true")
parser.add_argument("--show-bin", action="store_true")
args = parser.parse_args()
main(args.eval_spec, args.mcd, args.msd, args.show_bin)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import csv
from pathlib import Path
def main(args):
"""
`id syn ref text speaker`
"""
in_root = Path(args.generation_root).resolve()
ext = args.audio_format
with open(args.audio_manifest) as f, open(args.output_path, "w") as f_out:
reader = csv.DictReader(
f, delimiter="\t", quotechar=None, doublequote=False,
lineterminator="\n", quoting=csv.QUOTE_NONE
)
header = ["id", "syn", "ref", "text", "speaker"]
f_out.write("\t".join(header) + "\n")
for row in reader:
dir_name = f"{ext}_{args.sample_rate}hz_{args.vocoder}"
id_ = row["id"]
syn = (in_root / dir_name / f"{id_}.{ext}").as_posix()
ref = row["audio"]
if args.use_resynthesized_target:
ref = (in_root / f"{dir_name}_tgt" / f"{id_}.{ext}").as_posix()
sample = [id_, syn, ref, row["tgt_text"], row["speaker"]]
f_out.write("\t".join(sample) + "\n")
print(f"wrote evaluation file to {args.output_path}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--generation-root", help="output directory for generate_waveform.py"
)
parser.add_argument(
"--audio-manifest",
help="used to determine the original utterance ID and text"
)
parser.add_argument(
"--output-path", help="path to output evaluation spec file"
)
parser.add_argument(
"--use-resynthesized-target", action="store_true",
help="use resynthesized reference instead of the original audio"
)
parser.add_argument("--vocoder", type=str, default="griffin_lim")
parser.add_argument("--sample-rate", type=int, default=22_050)
parser.add_argument("--audio-format", type=str, default="wav")
args = parser.parse_args()
main(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import soundfile as sf
import sys
import torch
import torchaudio
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.logging import progress_bar
from fairseq.tasks.text_to_speech import plot_tts_output
from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDataset
logging.basicConfig()
logging.root.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def make_parser():
parser = options.get_speech_generation_parser()
parser.add_argument("--dump-features", action="store_true")
parser.add_argument("--dump-waveforms", action="store_true")
parser.add_argument("--dump-attentions", action="store_true")
parser.add_argument("--dump-eos-probs", action="store_true")
parser.add_argument("--dump-plots", action="store_true")
parser.add_argument("--dump-target", action="store_true")
parser.add_argument("--output-sample-rate", default=22050, type=int)
parser.add_argument("--teacher-forcing", action="store_true")
parser.add_argument(
"--audio-format", type=str, default="wav", choices=["wav", "flac"]
)
return parser
def postprocess_results(
dataset: TextToSpeechDataset, sample, hypos, resample_fn, dump_target
):
def to_np(x):
return None if x is None else x.detach().cpu().numpy()
sample_ids = [dataset.ids[i] for i in sample["id"].tolist()]
texts = sample["src_texts"]
attns = [to_np(hypo["attn"]) for hypo in hypos]
eos_probs = [to_np(hypo.get("eos_prob", None)) for hypo in hypos]
feat_preds = [to_np(hypo["feature"]) for hypo in hypos]
wave_preds = [to_np(resample_fn(h["waveform"])) for h in hypos]
if dump_target:
feat_targs = [to_np(hypo["targ_feature"]) for hypo in hypos]
wave_targs = [to_np(resample_fn(h["targ_waveform"])) for h in hypos]
else:
feat_targs = [None for _ in hypos]
wave_targs = [None for _ in hypos]
return zip(sample_ids, texts, attns, eos_probs, feat_preds, wave_preds,
feat_targs, wave_targs)
def dump_result(
is_na_model,
args,
vocoder,
sample_id,
text,
attn,
eos_prob,
feat_pred,
wave_pred,
feat_targ,
wave_targ,
):
sample_rate = args.output_sample_rate
out_root = Path(args.results_path)
if args.dump_features:
feat_dir = out_root / "feat"
feat_dir.mkdir(exist_ok=True, parents=True)
np.save(feat_dir / f"{sample_id}.npy", feat_pred)
if args.dump_target:
feat_tgt_dir = out_root / "feat_tgt"
feat_tgt_dir.mkdir(exist_ok=True, parents=True)
np.save(feat_tgt_dir / f"{sample_id}.npy", feat_targ)
if args.dump_attentions:
attn_dir = out_root / "attn"
attn_dir.mkdir(exist_ok=True, parents=True)
np.save(attn_dir / f"{sample_id}.npy", attn.numpy())
if args.dump_eos_probs and not is_na_model:
eos_dir = out_root / "eos"
eos_dir.mkdir(exist_ok=True, parents=True)
np.save(eos_dir / f"{sample_id}.npy", eos_prob)
if args.dump_plots:
images = [feat_pred.T] if is_na_model else [feat_pred.T, attn]
names = ["output"] if is_na_model else ["output", "alignment"]
if feat_targ is not None:
images = [feat_targ.T] + images
names = [f"target (idx={sample_id})"] + names
if is_na_model:
plot_tts_output(images, names, attn, "alignment", suptitle=text)
else:
plot_tts_output(images, names, eos_prob, "eos prob", suptitle=text)
plot_dir = out_root / "plot"
plot_dir.mkdir(exist_ok=True, parents=True)
plt.savefig(plot_dir / f"{sample_id}.png")
plt.close()
if args.dump_waveforms:
ext = args.audio_format
if wave_pred is not None:
wav_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}"
wav_dir.mkdir(exist_ok=True, parents=True)
sf.write(wav_dir / f"{sample_id}.{ext}", wave_pred, sample_rate)
if args.dump_target and wave_targ is not None:
wav_tgt_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}_tgt"
wav_tgt_dir.mkdir(exist_ok=True, parents=True)
sf.write(wav_tgt_dir / f"{sample_id}.{ext}", wave_targ, sample_rate)
def main(args):
assert(args.dump_features or args.dump_waveforms or args.dump_attentions
or args.dump_eos_probs or args.dump_plots)
if args.max_tokens is None and args.batch_size is None:
args.max_tokens = 8000
logger.info(args)
use_cuda = torch.cuda.is_available() and not args.cpu
task = tasks.setup_task(args)
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
[args.path],
task=task,
)
model = models[0].cuda() if use_cuda else models[0]
# use the original n_frames_per_step
task.args.n_frames_per_step = saved_cfg.task.n_frames_per_step
task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)
data_cfg = task.data_cfg
sample_rate = data_cfg.config.get("features", {}).get("sample_rate", 22050)
resample_fn = {
False: lambda x: x,
True: lambda x: torchaudio.sox_effects.apply_effects_tensor(
x.detach().cpu().unsqueeze(0), sample_rate,
[['rate', str(args.output_sample_rate)]]
)[0].squeeze(0)
}.get(args.output_sample_rate != sample_rate)
if args.output_sample_rate != sample_rate:
logger.info(f"resampling to {args.output_sample_rate}Hz")
generator = task.build_generator([model], args)
itr = task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
max_positions=(sys.maxsize, sys.maxsize),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=args.required_batch_size_multiple,
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
data_buffer_size=args.data_buffer_size,
).next_epoch_itr(shuffle=False)
Path(args.results_path).mkdir(exist_ok=True, parents=True)
is_na_model = getattr(model, "NON_AUTOREGRESSIVE", False)
dataset = task.dataset(args.gen_subset)
vocoder = task.args.vocoder
with progress_bar.build_progress_bar(args, itr) as t:
for sample in t:
sample = utils.move_to_cuda(sample) if use_cuda else sample
hypos = generator.generate(model, sample, has_targ=args.dump_target)
for result in postprocess_results(
dataset, sample, hypos, resample_fn, args.dump_target
):
dump_result(is_na_model, args, vocoder, *result)
def cli_main():
parser = make_parser()
args = options.parse_args_and_arch(parser)
main(args)
if __name__ == "__main__":
cli_main()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import os
import csv
import tempfile
from collections import defaultdict
from pathlib import Path
import torchaudio
try:
import webrtcvad
except ImportError:
raise ImportError("Please install py-webrtcvad: pip install webrtcvad")
import pandas as pd
from tqdm import tqdm
from examples.speech_synthesis.preprocessing.denoiser.pretrained import master64
import examples.speech_synthesis.preprocessing.denoiser.utils as utils
from examples.speech_synthesis.preprocessing.vad import (
frame_generator, vad_collector, read_wave, write_wave, FS_MS, THRESHOLD,
SCALE
)
from examples.speech_to_text.data_utils import save_df_to_tsv
log = logging.getLogger(__name__)
PATHS = ["after_denoise", "after_vad"]
MIN_T = 0.05
def generate_tmp_filename(extension="txt"):
return tempfile._get_default_tempdir() + "/" + \
next(tempfile._get_candidate_names()) + "." + extension
def convert_sr(inpath, sr, output_path=None):
if not output_path:
output_path = generate_tmp_filename("wav")
cmd = f"sox {inpath} -r {sr} {output_path}"
os.system(cmd)
return output_path
def apply_vad(vad, inpath):
audio, sample_rate = read_wave(inpath)
frames = frame_generator(FS_MS, audio, sample_rate)
frames = list(frames)
segments = vad_collector(sample_rate, FS_MS, 300, vad, frames)
merge_segments = list()
timestamp_start = 0.0
timestamp_end = 0.0
# removing start, end, and long sequences of sils
for i, segment in enumerate(segments):
merge_segments.append(segment[0])
if i and timestamp_start:
sil_duration = segment[1] - timestamp_end
if sil_duration > THRESHOLD:
merge_segments.append(int(THRESHOLD / SCALE) * (b'\x00'))
else:
merge_segments.append(int((sil_duration / SCALE)) * (b'\x00'))
timestamp_start = segment[1]
timestamp_end = segment[2]
segment = b''.join(merge_segments)
return segment, sample_rate
def write(wav, filename, sr=16_000):
# Normalize audio if it prevents clipping
wav = wav / max(wav.abs().max().item(), 1)
torchaudio.save(filename, wav.cpu(), sr, encoding="PCM_S",
bits_per_sample=16)
def process(args):
# make sure either denoising or VAD was requested
if not args.denoise and not args.vad:
log.error("No denoise or vad is requested.")
return
log.info("Creating out directories...")
if args.denoise:
out_denoise = Path(args.output_dir).absolute().joinpath(PATHS[0])
out_denoise.mkdir(parents=True, exist_ok=True)
if args.vad:
out_vad = Path(args.output_dir).absolute().joinpath(PATHS[1])
out_vad.mkdir(parents=True, exist_ok=True)
log.info("Loading pre-trained speech enhancement model...")
model = master64().to(args.device)
log.info("Building the VAD model...")
vad = webrtcvad.Vad(int(args.vad_agg_level))
# preparing the output dict
output_dict = defaultdict(list)
log.info(f"Parsing input manifest: {args.audio_manifest}")
with open(args.audio_manifest, "r") as f:
manifest_dict = csv.DictReader(f, delimiter="\t")
for row in tqdm(manifest_dict):
filename = str(row["audio"])
final_output = filename
keep_sample = True
n_frames = row["n_frames"]
snr = -1
if args.denoise:
output_path_denoise = out_denoise.joinpath(Path(filename).name)
# convert to 16kHz in case we use a different sr
tmp_path = convert_sr(final_output, 16000)
# loading audio file and generating the enhanced version
out, sr = torchaudio.load(tmp_path)
out = out.to(args.device)
estimate = model(out)
estimate = (1 - args.dry_wet) * estimate + args.dry_wet * out
write(estimate[0], str(output_path_denoise), sr)
snr = utils.cal_snr(out, estimate)
snr = snr.cpu().detach().numpy()[0][0]
final_output = str(output_path_denoise)
if args.vad:
output_path_vad = out_vad.joinpath(Path(filename).name)
sr = torchaudio.info(final_output).sample_rate
if sr in [16000, 32000, 48000]:
tmp_path = final_output
elif sr < 16000:
tmp_path = convert_sr(final_output, 16000)
elif sr < 32000:
tmp_path = convert_sr(final_output, 32000)
else:
tmp_path = convert_sr(final_output, 48000)
# apply VAD
segment, sample_rate = apply_vad(vad, tmp_path)
if len(segment) < sample_rate * MIN_T:
keep_sample = False
print((
f"WARNING: skip {filename} because it is too short "
f"after VAD ({len(segment) / sample_rate} < {MIN_T})"
))
else:
if sample_rate != sr:
tmp_path = generate_tmp_filename("wav")
write_wave(tmp_path, segment, sample_rate)
convert_sr(tmp_path, sr,
output_path=str(output_path_vad))
else:
write_wave(str(output_path_vad), segment, sample_rate)
final_output = str(output_path_vad)
segment, _ = torchaudio.load(final_output)
n_frames = segment.size(1)
if keep_sample:
output_dict["id"].append(row["id"])
output_dict["audio"].append(final_output)
output_dict["n_frames"].append(n_frames)
output_dict["tgt_text"].append(row["tgt_text"])
output_dict["speaker"].append(row["speaker"])
output_dict["src_text"].append(row["src_text"])
output_dict["snr"].append(snr)
out_tsv_path = Path(args.output_dir) / Path(args.audio_manifest).name
log.info(f"Saving manifest to {out_tsv_path.as_posix()}")
save_df_to_tsv(pd.DataFrame.from_dict(output_dict), out_tsv_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--audio-manifest", "-i", required=True,
type=str, help="path to the input manifest.")
parser.add_argument(
"--output-dir", "-o", required=True, type=str,
help="path to the output dir. it will contain files after denoising and"
" vad"
)
parser.add_argument("--vad-agg-level", "-a", type=int, default=2,
help="the aggresive level of the vad [0-3].")
parser.add_argument(
"--dry-wet", "-dw", type=float, default=0.01,
help="the level of linear interpolation between noisy and enhanced "
"files."
)
parser.add_argument(
"--device", "-d", type=str, default="cpu",
help="the device to be used for the speech enhancement model: "
"cpu | cuda."
)
parser.add_argument("--denoise", action="store_true",
help="apply a denoising")
parser.add_argument("--vad", action="store_true", help="apply a VAD")
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import math
import time
import torch as th
from torch import nn
from torch.nn import functional as F
from .resample import downsample2, upsample2
from .utils import capture_init
class BLSTM(nn.Module):
def __init__(self, dim, layers=2, bi=True):
super().__init__()
klass = nn.LSTM
self.lstm = klass(
bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim
)
self.linear = None
if bi:
self.linear = nn.Linear(2 * dim, dim)
def forward(self, x, hidden=None):
x, hidden = self.lstm(x, hidden)
if self.linear:
x = self.linear(x)
return x, hidden
def rescale_conv(conv, reference):
std = conv.weight.std().detach()
scale = (std / reference)**0.5
conv.weight.data /= scale
if conv.bias is not None:
conv.bias.data /= scale
def rescale_module(module, reference):
for sub in module.modules():
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
rescale_conv(sub, reference)
class Demucs(nn.Module):
"""
Demucs speech enhancement model.
Args:
- chin (int): number of input channels.
- chout (int): number of output channels.
- hidden (int): number of initial hidden channels.
- depth (int): number of layers.
- kernel_size (int): kernel size for each layer.
- stride (int): stride for each layer.
- causal (bool): if false, uses BiLSTM instead of LSTM.
- resample (int): amount of resampling to apply to the input/output.
Can be one of 1, 2 or 4.
- growth (float): number of channels is multiplied by this for every layer.
- max_hidden (int): maximum number of channels. Can be useful to
control the size/speed of the model.
- normalize (bool): if true, normalize the input.
- glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions.
- rescale (float): controls custom weight initialization.
See https://arxiv.org/abs/1911.13254.
- floor (float): stability flooring when normalizing.
"""
@capture_init
def __init__(self,
chin=1,
chout=1,
hidden=48,
depth=5,
kernel_size=8,
stride=4,
causal=True,
resample=4,
growth=2,
max_hidden=10_000,
normalize=True,
glu=True,
rescale=0.1,
floor=1e-3):
super().__init__()
if resample not in [1, 2, 4]:
raise ValueError("Resample should be 1, 2 or 4.")
self.chin = chin
self.chout = chout
self.hidden = hidden
self.depth = depth
self.kernel_size = kernel_size
self.stride = stride
self.causal = causal
self.floor = floor
self.resample = resample
self.normalize = normalize
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
activation = nn.GLU(1) if glu else nn.ReLU()
ch_scale = 2 if glu else 1
for index in range(depth):
encode = []
encode += [
nn.Conv1d(chin, hidden, kernel_size, stride),
nn.ReLU(),
nn.Conv1d(hidden, hidden * ch_scale, 1), activation,
]
self.encoder.append(nn.Sequential(*encode))
decode = []
decode += [
nn.Conv1d(hidden, ch_scale * hidden, 1), activation,
nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
]
if index > 0:
decode.append(nn.ReLU())
self.decoder.insert(0, nn.Sequential(*decode))
chout = hidden
chin = hidden
hidden = min(int(growth * hidden), max_hidden)
self.lstm = BLSTM(chin, bi=not causal)
if rescale:
rescale_module(self, reference=rescale)
def valid_length(self, length):
"""
Return the nearest valid length to use with the model so that
there are no time steps left over in the convolutions, i.e. for all
layers, (input length - kernel_size) % stride == 0.
If the mixture has a valid length, the estimated sources
will have exactly the same length.
"""
length = math.ceil(length * self.resample)
for _ in range(self.depth):
length = math.ceil((length - self.kernel_size) / self.stride) + 1
length = max(length, 1)
for _ in range(self.depth):
length = (length - 1) * self.stride + self.kernel_size
length = int(math.ceil(length / self.resample))
return int(length)
@property
def total_stride(self):
return self.stride ** self.depth // self.resample
def forward(self, mix):
if mix.dim() == 2:
mix = mix.unsqueeze(1)
if self.normalize:
mono = mix.mean(dim=1, keepdim=True)
std = mono.std(dim=-1, keepdim=True)
mix = mix / (self.floor + std)
else:
std = 1
length = mix.shape[-1]
x = mix
x = F.pad(x, (0, self.valid_length(length) - length))
if self.resample == 2:
x = upsample2(x)
elif self.resample == 4:
x = upsample2(x)
x = upsample2(x)
skips = []
for encode in self.encoder:
x = encode(x)
skips.append(x)
x = x.permute(2, 0, 1)
x, _ = self.lstm(x)
x = x.permute(1, 2, 0)
for decode in self.decoder:
skip = skips.pop(-1)
x = x + skip[..., :x.shape[-1]]
x = decode(x)
if self.resample == 2:
x = downsample2(x)
elif self.resample == 4:
x = downsample2(x)
x = downsample2(x)
x = x[..., :length]
return std * x
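# A minimal usage sketch (illustrative): enhance a mono waveform with the
# default configuration; see test() at the bottom of this file for a fuller
# streaming benchmark.
#   model = Demucs()
#   noisy = th.randn(1, 1, 16_000)   # (batch, channels, time)
#   with th.no_grad():
#       enhanced = model(noisy)      # same shape/length as the input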
def fast_conv(conv, x):
"""
Faster convolution evaluation if either kernel size is 1
or length of sequence is 1.
"""
batch, chin, length = x.shape
chout, chin, kernel = conv.weight.shape
assert batch == 1
if kernel == 1:
x = x.view(chin, length)
out = th.addmm(conv.bias.view(-1, 1),
conv.weight.view(chout, chin), x)
elif length == kernel:
x = x.view(chin * kernel, 1)
out = th.addmm(conv.bias.view(-1, 1),
conv.weight.view(chout, chin * kernel), x)
else:
out = conv(x)
return out.view(batch, chout, -1)
class DemucsStreamer:
"""
Streaming implementation for Demucs. It supports being fed with any amount
of audio at a time. You will get back as much audio as possible at that
point.
Args:
- demucs (Demucs): Demucs model.
- dry (float): amount of dry (e.g. input) signal to keep. 0 is maximum
noise removal, 1 just returns the input signal. Small values > 0
help limit distortions.
- num_frames (int): number of frames to process at once. Higher values
will increase overall latency but improve the real time factor.
- resample_lookahead (int): extra lookahead used for the resampling.
- resample_buffer (int): size of the buffer of previous inputs/outputs
kept for resampling.
"""
def __init__(self, demucs,
dry=0,
num_frames=1,
resample_lookahead=64,
resample_buffer=256):
device = next(iter(demucs.parameters())).device
self.demucs = demucs
self.lstm_state = None
self.conv_state = None
self.dry = dry
self.resample_lookahead = resample_lookahead
resample_buffer = min(demucs.total_stride, resample_buffer)
self.resample_buffer = resample_buffer
self.frame_length = demucs.valid_length(1) + \
demucs.total_stride * (num_frames - 1)
self.total_length = self.frame_length + self.resample_lookahead
self.stride = demucs.total_stride * num_frames
self.resample_in = th.zeros(demucs.chin, resample_buffer, device=device)
self.resample_out = th.zeros(
demucs.chin, resample_buffer, device=device
)
self.frames = 0
self.total_time = 0
self.variance = 0
self.pending = th.zeros(demucs.chin, 0, device=device)
bias = demucs.decoder[0][2].bias
weight = demucs.decoder[0][2].weight
chin, chout, kernel = weight.shape
self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
self._weight = weight.permute(1, 2, 0).contiguous()
def reset_time_per_frame(self):
self.total_time = 0
self.frames = 0
@property
def time_per_frame(self):
return self.total_time / self.frames
def flush(self):
"""
Flush remaining audio by padding it with zeros. Call this
when you have no more input and want to get back the last chunk of audio.
"""
pending_length = self.pending.shape[1]
padding = th.zeros(
self.demucs.chin, self.total_length, device=self.pending.device
)
out = self.feed(padding)
return out[:, :pending_length]
def feed(self, wav):
"""
Apply the model to mix using true real-time evaluation.
Normalization and resampling are done online.
"""
begin = time.time()
demucs = self.demucs
resample_buffer = self.resample_buffer
stride = self.stride
resample = demucs.resample
if wav.dim() != 2:
raise ValueError("input wav should be two dimensional.")
chin, _ = wav.shape
if chin != demucs.chin:
raise ValueError(f"Expected {demucs.chin} channels, got {chin}")
self.pending = th.cat([self.pending, wav], dim=1)
outs = []
while self.pending.shape[1] >= self.total_length:
self.frames += 1
frame = self.pending[:, :self.total_length]
dry_signal = frame[:, :stride]
if demucs.normalize:
mono = frame.mean(0)
variance = (mono**2).mean()
self.variance = variance / self.frames + \
(1 - 1 / self.frames) * self.variance
frame = frame / (demucs.floor + math.sqrt(self.variance))
frame = th.cat([self.resample_in, frame], dim=-1)
self.resample_in[:] = frame[:, stride - resample_buffer:stride]
if resample == 4:
frame = upsample2(upsample2(frame))
elif resample == 2:
frame = upsample2(frame)
# remove pre sampling buffer
frame = frame[:, resample * resample_buffer:]
# remove extra samples after window
frame = frame[:, :resample * self.frame_length]
out, extra = self._separate_frame(frame)
padded_out = th.cat([self.resample_out, out, extra], 1)
self.resample_out[:] = out[:, -resample_buffer:]
if resample == 4:
out = downsample2(downsample2(padded_out))
elif resample == 2:
out = downsample2(padded_out)
else:
out = padded_out
out = out[:, resample_buffer // resample:]
out = out[:, :stride]
if demucs.normalize:
out *= math.sqrt(self.variance)
out = self.dry * dry_signal + (1 - self.dry) * out
outs.append(out)
self.pending = self.pending[:, stride:]
self.total_time += time.time() - begin
if outs:
out = th.cat(outs, 1)
else:
out = th.zeros(chin, 0, device=wav.device)
return out
def _separate_frame(self, frame):
demucs = self.demucs
skips = []
next_state = []
first = self.conv_state is None
stride = self.stride * demucs.resample
x = frame[None]
for idx, encode in enumerate(demucs.encoder):
stride //= demucs.stride
length = x.shape[2]
if idx == demucs.depth - 1:
                # This is slightly faster for the last conv
x = fast_conv(encode[0], x)
x = encode[1](x)
x = fast_conv(encode[2], x)
x = encode[3](x)
else:
if not first:
prev = self.conv_state.pop(0)
prev = prev[..., stride:]
tgt = (length - demucs.kernel_size) // demucs.stride + 1
missing = tgt - prev.shape[-1]
offset = length - demucs.kernel_size - \
demucs.stride * (missing - 1)
x = x[..., offset:]
x = encode[1](encode[0](x))
x = fast_conv(encode[2], x)
x = encode[3](x)
if not first:
x = th.cat([prev, x], -1)
next_state.append(x)
skips.append(x)
x = x.permute(2, 0, 1)
x, self.lstm_state = demucs.lstm(x, self.lstm_state)
x = x.permute(1, 2, 0)
        # In the following, x contains only correct samples, i.e. the ones
        # for which each time position is covered by two windows of the upper
        # layer. extra contains extra samples to the right, used only as
        # better padding for the online resampling.
extra = None
for idx, decode in enumerate(demucs.decoder):
skip = skips.pop(-1)
x += skip[..., :x.shape[-1]]
x = fast_conv(decode[0], x)
x = decode[1](x)
if extra is not None:
skip = skip[..., x.shape[-1]:]
extra += skip[..., :extra.shape[-1]]
extra = decode[2](decode[1](decode[0](extra)))
x = decode[2](x)
next_state.append(
x[..., -demucs.stride:] - decode[2].bias.view(-1, 1)
)
if extra is None:
extra = x[..., -demucs.stride:]
else:
extra[..., :demucs.stride] += next_state[-1]
x = x[..., :-demucs.stride]
if not first:
prev = self.conv_state.pop(0)
x[..., :demucs.stride] += prev
if idx != demucs.depth - 1:
x = decode[3](x)
extra = decode[3](extra)
self.conv_state = next_state
return x[0], extra[0]
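# What follows is a minimal, hypothetical usage sketch (not part of the original
# module): it shows how DemucsStreamer is meant to be driven chunk by chunk, feeding
# arbitrarily sized blocks of audio and flushing once the input is exhausted. The
# chunk size, sample count and hidden size below are illustrative assumptions.
def _streaming_example_sketch():
    demucs = Demucs(hidden=32)
    demucs.eval()
    streamer = DemucsStreamer(demucs, num_frames=1)
    noisy = th.randn(demucs.chin, 16000)  # one second of fake audio at 16 kHz
    enhanced = []
    with th.no_grad():
        # feed() buffers input internally and only emits complete strides
        for offset in range(0, noisy.shape[1], 1024):
            enhanced.append(streamer.feed(noisy[:, offset:offset + 1024]))
        # flush() pads the remaining buffered audio with zeros and returns the tail
        enhanced.append(streamer.flush())
    return th.cat(enhanced, dim=1)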
def test():
import argparse
parser = argparse.ArgumentParser(
"denoiser.demucs",
description="Benchmark the streaming Demucs implementation, as well as "
"checking the delta with the offline implementation.")
parser.add_argument("--depth", default=5, type=int)
parser.add_argument("--resample", default=4, type=int)
parser.add_argument("--hidden", default=48, type=int)
parser.add_argument("--sample_rate", default=16000, type=float)
parser.add_argument("--device", default="cpu")
parser.add_argument("-t", "--num_threads", type=int)
parser.add_argument("-f", "--num_frames", type=int, default=1)
args = parser.parse_args()
if args.num_threads:
th.set_num_threads(args.num_threads)
sr = args.sample_rate
sr_ms = sr / 1000
demucs = Demucs(
depth=args.depth, hidden=args.hidden, resample=args.resample
).to(args.device)
x = th.randn(1, int(sr * 4)).to(args.device)
out = demucs(x[None])[0]
streamer = DemucsStreamer(demucs, num_frames=args.num_frames)
out_rt = []
frame_size = streamer.total_length
with th.no_grad():
while x.shape[1] > 0:
out_rt.append(streamer.feed(x[:, :frame_size]))
x = x[:, frame_size:]
frame_size = streamer.demucs.total_stride
out_rt.append(streamer.flush())
out_rt = th.cat(out_rt, 1)
model_size = sum(p.numel() for p in demucs.parameters()) * 4 / 2**20
initial_lag = streamer.total_length / sr_ms
tpf = 1000 * streamer.time_per_frame
print(f"model size: {model_size:.1f}MB, ", end='')
print(f"delta batch/streaming: {th.norm(out - out_rt) / th.norm(out):.2%}")
print(f"initial lag: {initial_lag:.1f}ms, ", end='')
print(f"stride: {streamer.stride * args.num_frames / sr_ms:.1f}ms")
print(f"time per frame: {tpf:.1f}ms, ", end='')
rtf = (1000 * streamer.time_per_frame) / (streamer.stride / sr_ms)
print(f"RTF: {rtf:.2f}")
print(f"Total lag with computation: {initial_lag + tpf:.1f}ms")
if __name__ == "__main__":
test()
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import logging
import torch.hub
from .demucs import Demucs
from .utils import deserialize_model
logger = logging.getLogger(__name__)
ROOT = "https://dl.fbaipublicfiles.com/adiyoss/denoiser/"
DNS_48_URL = ROOT + "dns48-11decc9d8e3f0998.th"
DNS_64_URL = ROOT + "dns64-a7761ff99a7d5bb6.th"
MASTER_64_URL = ROOT + "master64-8a5dfb4bb92753dd.th"
def _demucs(pretrained, url, **kwargs):
model = Demucs(**kwargs)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
model.load_state_dict(state_dict)
return model
def dns48(pretrained=True):
return _demucs(pretrained, DNS_48_URL, hidden=48)
def dns64(pretrained=True):
return _demucs(pretrained, DNS_64_URL, hidden=64)
def master64(pretrained=True):
return _demucs(pretrained, MASTER_64_URL, hidden=64)
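# A small illustrative sketch (an assumption, not in the original file): load one of
# the pre-trained checkpoints above and run it on a dummy waveform. The tensor shape
# (batch, channels, time) and the one-second length at 16 kHz are assumptions.
def _pretrained_example_sketch():
    import torch
    model = dns48(pretrained=True)  # fetches DNS_48_URL through torch.hub on first use
    model.eval()
    noisy = torch.randn(1, 1, 16000)  # one second of fake mono audio
    with torch.no_grad():
        enhanced = model(noisy)
    return enhanced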
def add_model_flags(parser):
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument(
"-m", "--model_path", help="Path to local trained model."
)
group.add_argument(
"--dns48", action="store_true",
help="Use pre-trained real time H=48 model trained on DNS."
)
group.add_argument(
"--dns64", action="store_true",
help="Use pre-trained real time H=64 model trained on DNS."
)
group.add_argument(
"--master64", action="store_true",
help="Use pre-trained real time H=64 model trained on DNS and Valentini."
)
def get_model(args):
"""
Load local model package or torchhub pre-trained model.
"""
if args.model_path:
logger.info("Loading model from %s", args.model_path)
pkg = torch.load(args.model_path)
model = deserialize_model(pkg)
elif args.dns64:
logger.info("Loading pre-trained real time H=64 model trained on DNS.")
model = dns64()
elif args.master64:
logger.info(
"Loading pre-trained real time H=64 model trained on DNS and Valentini."
)
model = master64()
else:
logger.info("Loading pre-trained real time H=48 model trained on DNS.")
model = dns48()
logger.debug(model)
return model
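# A short sketch (illustration only, not in the original file) of wiring
# add_model_flags and get_model together with argparse; the parser name and the
# example arguments are assumptions.
def _get_model_example_sketch(argv=("--dns64",)):
    import argparse
    parser = argparse.ArgumentParser("denoiser_example")
    add_model_flags(parser)
    args = parser.parse_args(list(argv))
    return get_model(args)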
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import math
import torch as th
from torch.nn import functional as F
def sinc(t):
"""sinc.
:param t: the input tensor
"""
return th.where(t == 0, th.tensor(1., device=t.device, dtype=t.dtype),
th.sin(t) / t)
def kernel_upsample2(zeros=56):
"""kernel_upsample2.
"""
win = th.hann_window(4 * zeros + 1, periodic=False)
winodd = win[1::2]
t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
t *= math.pi
kernel = (sinc(t) * winodd).view(1, 1, -1)
return kernel
def upsample2(x, zeros=56):
"""
    Upsample the input by a factor of 2 using sinc interpolation.
Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method."
ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing.
Vol. 9. IEEE, 1984.
"""
*other, time = x.shape
kernel = kernel_upsample2(zeros).to(x)
out = F.conv1d(x.view(-1, 1, time), kernel, padding=zeros)[..., 1:].view(
*other, time
)
y = th.stack([x, out], dim=-1)
return y.view(*other, -1)
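# Quick illustrative check (a sketch, not part of the original module): upsample2
# doubles the length of the last dimension, so a (1, 1000)-sample signal becomes
# (1, 2000) samples. The shapes here are arbitrary assumptions.
def _upsample2_example_sketch():
    x = th.randn(1, 1000)
    y = upsample2(x)
    assert y.shape == (1, 2000)
    return y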
def kernel_downsample2(zeros=56):
"""kernel_downsample2.
"""
win = th.hann_window(4 * zeros + 1, periodic=False)
winodd = win[1::2]
t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
t.mul_(math.pi)
kernel = (sinc(t) * winodd).view(1, 1, -1)
return kernel
def downsample2(x, zeros=56):
"""
    Downsample the input by a factor of 2 using sinc interpolation.
Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method."
ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing.
Vol. 9. IEEE, 1984.
"""
if x.shape[-1] % 2 != 0:
x = F.pad(x, (0, 1))
xeven = x[..., ::2]
xodd = x[..., 1::2]
*other, time = xodd.shape
kernel = kernel_downsample2(zeros).to(x)
out = xeven + F.conv1d(
xodd.view(-1, 1, time), kernel, padding=zeros
)[..., :-1].view(*other, time)
return out.view(*other, -1).mul(0.5)
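# Another small sketch (an assumption, not in the original module): a low-frequency
# sine upsampled by 2 and then downsampled by 2 should come back nearly unchanged,
# since both directions use the same windowed-sinc kernel. The frequency and length
# are arbitrary illustrative choices.
def _resample_roundtrip_sketch():
    t = th.arange(4096, dtype=th.float32)
    x = th.sin(2 * math.pi * 440.0 / 16000.0 * t).view(1, -1)
    y = downsample2(upsample2(x))
    assert y.shape == x.shape
    # relative round-trip error, expected to be small away from the signal edges
    return (th.norm(x - y) / th.norm(x)).item()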
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import functools
import logging
from contextlib import contextmanager
import inspect
import time
logger = logging.getLogger(__name__)
EPS = 1e-8
def capture_init(init):
"""capture_init.
Decorate `__init__` with this, and you can then
recover the *args and **kwargs passed to it in `self._init_args_kwargs`
"""
@functools.wraps(init)
def __init__(self, *args, **kwargs):
self._init_args_kwargs = (args, kwargs)
init(self, *args, **kwargs)
return __init__
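# A minimal sketch of how capture_init is meant to be used (an illustration, not part
# of the original file): the decorated class remembers its constructor arguments,
# which is what serialize_model below relies on. The class and its parameters are
# hypothetical.
class _CaptureInitExample:
    @capture_init
    def __init__(self, hidden=48, depth=5):
        self.hidden = hidden
        self.depth = depth
# _CaptureInitExample(hidden=64)._init_args_kwargs == ((), {"hidden": 64})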
def deserialize_model(package, strict=False):
"""deserialize_model.
"""
klass = package['class']
if strict:
model = klass(*package['args'], **package['kwargs'])
else:
sig = inspect.signature(klass)
kw = package['kwargs']
for key in list(kw):
if key not in sig.parameters:
logger.warning("Dropping inexistant parameter %s", key)
del kw[key]
model = klass(*package['args'], **kw)
model.load_state_dict(package['state'])
return model
def copy_state(state):
return {k: v.cpu().clone() for k, v in state.items()}
def serialize_model(model):
args, kwargs = model._init_args_kwargs
state = copy_state(model.state_dict())
return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state}
@contextmanager
def swap_state(model, state):
"""
    Context manager that swaps the state of a model, e.g.:
# model is in old state
with swap_state(model, new_state):
# model in new state
# model back to old state
"""
old_state = copy_state(model.state_dict())
model.load_state_dict(state)
try:
yield
finally:
model.load_state_dict(old_state)
def pull_metric(history, name):
out = []
for metrics in history:
if name in metrics:
out.append(metrics[name])
return out
class LogProgress:
"""
    Sort of like tqdm, but using log lines rather than real-time updates.
Args:
- logger: logger obtained from `logging.getLogger`,
- iterable: iterable object to wrap
- updates (int): number of lines that will be printed, e.g.
if `updates=5`, log every 1/5th of the total length.
- total (int): length of the iterable, in case it does not support
`len`.
- name (str): prefix to use in the log.
- level: logging level (like `logging.INFO`).
"""
def __init__(self,
logger,
iterable,
updates=5,
total=None,
name="LogProgress",
level=logging.INFO):
self.iterable = iterable
self.total = total or len(iterable)
self.updates = updates
self.name = name
self.logger = logger
self.level = level
def update(self, **infos):
self._infos = infos
def __iter__(self):
self._iterator = iter(self.iterable)
self._index = -1
self._infos = {}
self._begin = time.time()
return self
def __next__(self):
self._index += 1
try:
value = next(self._iterator)
except StopIteration:
raise
else:
return value
finally:
log_every = max(1, self.total // self.updates)
            # logging is delayed by one iteration so that the metrics passed via `update` are included
if self._index >= 1 and self._index % log_every == 0:
self._log()
def _log(self):
self._speed = (1 + self._index) / (time.time() - self._begin)
infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items())
if self._speed < 1e-4:
speed = "oo sec/it"
elif self._speed < 0.1:
speed = f"{1/self._speed:.1f} sec/it"
else:
speed = f"{self._speed:.1f} it/sec"
out = f"{self.name} | {self._index}/{self.total} | {speed}"
if infos:
out += " | " + infos
self.logger.log(self.level, out)
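# A minimal usage sketch for LogProgress (an illustration, not part of the original
# file): wrap any iterable, optionally report metrics via update(), and a log line is
# emitted roughly `updates` times over the full iteration. The metric name is an
# assumption.
def _log_progress_example_sketch():
    data = range(1000)
    progress = LogProgress(logger, data, updates=5, name="Example")
    total = 0
    for step in progress:
        total += step
        progress.update(loss=f"{total:.1f}")
    return total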
def colorize(text, color):
"""
Display text with some ANSI color in the terminal.
"""
code = f"\033[{color}m"
restore = "\033[0m"
return "".join([code, text, restore])
def bold(text):
"""
Display text in bold in the terminal.
"""
return colorize(text, "1")
def cal_snr(lbl, est):
import torch
y = 10.0 * torch.log10(
torch.sum(lbl**2, dim=-1) / (torch.sum((est-lbl)**2, dim=-1) + EPS) +
EPS
)
return y
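# A short sketch (illustration only) of cal_snr: an estimate equal to the label plus
# small noise yields a large SNR in dB. Tensor shapes and the noise scale are
# assumptions.
def _cal_snr_example_sketch():
    import torch
    clean = torch.randn(2, 16000)
    noisy_estimate = clean + 0.01 * torch.randn(2, 16000)
    return cal_snr(clean, noisy_estimate)  # roughly 40 dB per example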
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
from pathlib import Path
from collections import defaultdict
from typing import List, Dict, Tuple
import pandas as pd
import numpy as np
import torchaudio
from tqdm import tqdm
from examples.speech_to_text.data_utils import load_df_from_tsv, save_df_to_tsv
log = logging.getLogger(__name__)
SPLITS = ["train", "dev", "test"]
def get_top_n(
root: Path, n_speakers: int = 10, min_n_tokens: int = 5
) -> pd.DataFrame:
df = load_df_from_tsv(root / "validated.tsv")
df["n_tokens"] = [len(s.split()) for s in df["sentence"]]
df = df[df["n_tokens"] >= min_n_tokens]
df["n_frames"] = [
torchaudio.info((root / "clips" / p).as_posix()).num_frames
for p in tqdm(df["path"])
]
df["id"] = [Path(p).stem for p in df["path"]]
total_duration_ms = df.groupby("client_id")["n_frames"].agg(["sum"])
total_duration_ms = total_duration_ms.sort_values("sum", ascending=False)
top_n_total_duration_ms = total_duration_ms.head(n_speakers)
top_n_client_ids = set(top_n_total_duration_ms.index.tolist())
df_top_n = df[df["client_id"].isin(top_n_client_ids)]
return df_top_n
def get_splits(
df, train_split_ratio=0.99, speaker_in_all_splits=False, rand_seed=0
) -> Tuple[Dict[str, str], List[str]]:
np.random.seed(rand_seed)
dev_split_ratio = (1. - train_split_ratio) / 3
grouped = list(df.groupby("client_id"))
id_to_split = {}
for _, cur_df in tqdm(grouped):
cur_n_examples = len(cur_df)
if speaker_in_all_splits and cur_n_examples < 3:
continue
cur_n_train = int(cur_n_examples * train_split_ratio)
cur_n_dev = int(cur_n_examples * dev_split_ratio)
cur_n_test = cur_n_examples - cur_n_dev - cur_n_train
if speaker_in_all_splits and cur_n_dev * cur_n_test == 0:
cur_n_dev, cur_n_test = 1, 1
cur_n_train = cur_n_examples - cur_n_dev - cur_n_test
cur_indices = cur_df.index.tolist()
cur_shuffled_indices = np.random.permutation(cur_n_examples)
cur_shuffled_indices = [cur_indices[i] for i in cur_shuffled_indices]
cur_indices_by_split = {
"train": cur_shuffled_indices[:cur_n_train],
"dev": cur_shuffled_indices[cur_n_train: cur_n_train + cur_n_dev],
"test": cur_shuffled_indices[cur_n_train + cur_n_dev:]
}
for split in SPLITS:
for i in cur_indices_by_split[split]:
id_ = df["id"].loc[i]
id_to_split[id_] = split
return id_to_split, sorted(df["client_id"].unique())
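# A tiny illustrative sketch (not part of the original script): get_splits assigns
# each utterance id of every speaker to train/dev/test, keeping roughly
# train_split_ratio of each speaker's data in train. The speaker and utterance ids
# below are hypothetical.
def _get_splits_example_sketch():
    df = pd.DataFrame({
        "client_id": ["spk_a"] * 6 + ["spk_b"] * 6,
        "id": [f"utt_{i}" for i in range(12)],
    })
    id_to_split, speakers = get_splits(
        df, train_split_ratio=0.5, speaker_in_all_splits=True
    )
    return id_to_split, speakers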
def convert_to_wav(root: Path, filenames: List[str], target_sr=16_000):
out_root = root / "wav"
out_root.mkdir(exist_ok=True, parents=True)
print("Converting to WAV...")
for n in tqdm(filenames):
in_path = (root / "clips" / n).as_posix()
waveform, sr = torchaudio.load(in_path)
converted, converted_sr = torchaudio.sox_effects.apply_effects_tensor(
waveform, sr, [["rate", str(target_sr)], ["channels", "1"]]
)
out_path = (out_root / Path(n).with_suffix(".wav").name).as_posix()
torchaudio.save(out_path, converted, converted_sr, encoding="PCM_S",
bits_per_sample=16)
def process(args):
data_root = Path(args.data_root).absolute() / args.lang
# Generate TSV manifest
print("Generating manifest...")
df_top_n = get_top_n(data_root)
id_to_split, speakers = get_splits(df_top_n)
if args.convert_to_wav:
convert_to_wav(data_root, df_top_n["path"].tolist())
manifest_by_split = {split: defaultdict(list) for split in SPLITS}
for sample in tqdm(df_top_n.to_dict(orient="index").values()):
sample_id = sample["id"]
split = id_to_split[sample_id]
manifest_by_split[split]["id"].append(sample_id)
if args.convert_to_wav:
audio_path = data_root / "wav" / f"{sample_id}.wav"
else:
audio_path = data_root / "clips" / f"{sample_id}.mp3"
manifest_by_split[split]["audio"].append(audio_path.as_posix())
manifest_by_split[split]["n_frames"].append(sample["n_frames"])
manifest_by_split[split]["tgt_text"].append(sample["sentence"])
manifest_by_split[split]["speaker"].append(sample["client_id"])
manifest_by_split[split]["src_text"].append(sample["sentence"])
output_root = Path(args.output_manifest_root).absolute()
output_root.mkdir(parents=True, exist_ok=True)
for split in SPLITS:
save_df_to_tsv(
pd.DataFrame.from_dict(manifest_by_split[split]),
output_root / f"{split}.audio.tsv"
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data-root", "-d", required=True, type=str)
parser.add_argument("--output-manifest-root", "-m", required=True, type=str)
parser.add_argument("--lang", "-l", required=True, type=str)
parser.add_argument("--convert-to-wav", action="store_true")
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
from pathlib import Path
import shutil
from tempfile import NamedTemporaryFile
from collections import Counter, defaultdict
import pandas as pd
import torchaudio
from tqdm import tqdm
from fairseq.data.audio.audio_utils import convert_waveform
from examples.speech_to_text.data_utils import (
create_zip,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
load_tsv_to_dicts,
save_df_to_tsv
)
from examples.speech_synthesis.data_utils import (
extract_logmel_spectrogram, extract_pitch, extract_energy, get_global_cmvn,
ipa_phonemize, get_mfa_alignment, get_unit_alignment
)
log = logging.getLogger(__name__)
def process(args):
assert "train" in args.splits
out_root = Path(args.output_root).absolute()
out_root.mkdir(exist_ok=True)
print("Fetching data...")
audio_manifest_root = Path(args.audio_manifest_root).absolute()
samples = []
for s in args.splits:
for e in load_tsv_to_dicts(audio_manifest_root / f"{s}.audio.tsv"):
e["split"] = s
samples.append(e)
sample_ids = [s["id"] for s in samples]
# Get alignment info
id_to_alignment = None
if args.textgrid_zip is not None:
assert args.id_to_units_tsv is None
id_to_alignment = get_mfa_alignment(
args.textgrid_zip, sample_ids, args.sample_rate, args.hop_length
)
elif args.id_to_units_tsv is not None:
# assume identical hop length on the unit sequence
id_to_alignment = get_unit_alignment(args.id_to_units_tsv, sample_ids)
# Extract features and pack features into ZIP
feature_name = "logmelspec80"
zip_path = out_root / f"{feature_name}.zip"
pitch_zip_path = out_root / "pitch.zip"
energy_zip_path = out_root / "energy.zip"
gcmvn_npz_path = out_root / "gcmvn_stats.npz"
if zip_path.exists() and gcmvn_npz_path.exists():
print(f"{zip_path} and {gcmvn_npz_path} exist.")
else:
feature_root = out_root / feature_name
feature_root.mkdir(exist_ok=True)
pitch_root = out_root / "pitch"
energy_root = out_root / "energy"
if args.add_fastspeech_targets:
pitch_root.mkdir(exist_ok=True)
energy_root.mkdir(exist_ok=True)
print("Extracting Mel spectrogram features...")
for sample in tqdm(samples):
waveform, sample_rate = torchaudio.load(sample["audio"])
waveform, sample_rate = convert_waveform(
waveform, sample_rate, normalize_volume=args.normalize_volume,
to_sample_rate=args.sample_rate
)
sample_id = sample["id"]
target_length = None
if id_to_alignment is not None:
a = id_to_alignment[sample_id]
target_length = sum(a.frame_durations)
if a.start_sec is not None and a.end_sec is not None:
start_frame = int(a.start_sec * sample_rate)
end_frame = int(a.end_sec * sample_rate)
waveform = waveform[:, start_frame: end_frame]
extract_logmel_spectrogram(
waveform, sample_rate, feature_root / f"{sample_id}.npy",
win_length=args.win_length, hop_length=args.hop_length,
n_fft=args.n_fft, n_mels=args.n_mels, f_min=args.f_min,
f_max=args.f_max, target_length=target_length
)
if args.add_fastspeech_targets:
assert id_to_alignment is not None
extract_pitch(
waveform, sample_rate, pitch_root / f"{sample_id}.npy",
hop_length=args.hop_length, log_scale=True,
phoneme_durations=id_to_alignment[sample_id].frame_durations
)
extract_energy(
waveform, energy_root / f"{sample_id}.npy",
hop_length=args.hop_length, n_fft=args.n_fft,
log_scale=True,
phoneme_durations=id_to_alignment[sample_id].frame_durations
)
print("ZIPing features...")
create_zip(feature_root, zip_path)
get_global_cmvn(feature_root, gcmvn_npz_path)
shutil.rmtree(feature_root)
if args.add_fastspeech_targets:
create_zip(pitch_root, pitch_zip_path)
shutil.rmtree(pitch_root)
create_zip(energy_root, energy_zip_path)
shutil.rmtree(energy_root)
print("Fetching ZIP manifest...")
audio_paths, audio_lengths = get_zip_manifest(zip_path)
pitch_paths, pitch_lengths, energy_paths, energy_lengths = [None] * 4
if args.add_fastspeech_targets:
pitch_paths, pitch_lengths = get_zip_manifest(pitch_zip_path)
energy_paths, energy_lengths = get_zip_manifest(energy_zip_path)
# Generate TSV manifest
print("Generating manifest...")
manifest_by_split = {split: defaultdict(list) for split in args.splits}
for sample in tqdm(samples):
sample_id, split = sample["id"], sample["split"]
normalized_utt = sample["tgt_text"]
if id_to_alignment is not None:
normalized_utt = " ".join(id_to_alignment[sample_id].tokens)
elif args.ipa_vocab:
normalized_utt = ipa_phonemize(
normalized_utt, lang=args.lang, use_g2p=args.use_g2p
)
manifest_by_split[split]["id"].append(sample_id)
manifest_by_split[split]["audio"].append(audio_paths[sample_id])
manifest_by_split[split]["n_frames"].append(audio_lengths[sample_id])
manifest_by_split[split]["tgt_text"].append(normalized_utt)
manifest_by_split[split]["speaker"].append(sample["speaker"])
manifest_by_split[split]["src_text"].append(sample["src_text"])
if args.add_fastspeech_targets:
assert id_to_alignment is not None
duration = " ".join(
str(d) for d in id_to_alignment[sample_id].frame_durations
)
manifest_by_split[split]["duration"].append(duration)
manifest_by_split[split]["pitch"].append(pitch_paths[sample_id])
manifest_by_split[split]["energy"].append(energy_paths[sample_id])
for split in args.splits:
save_df_to_tsv(
pd.DataFrame.from_dict(manifest_by_split[split]),
out_root / f"{split}.tsv"
)
# Generate vocab
vocab_name, spm_filename = None, None
if id_to_alignment is not None or args.ipa_vocab:
vocab = Counter()
for t in manifest_by_split["train"]["tgt_text"]:
vocab.update(t.split(" "))
vocab_name = "vocab.txt"
with open(out_root / vocab_name, "w") as f:
for s, c in vocab.most_common():
f.write(f"{s} {c}\n")
else:
spm_filename_prefix = "spm_char"
spm_filename = f"{spm_filename_prefix}.model"
with NamedTemporaryFile(mode="w") as f:
for t in manifest_by_split["train"]["tgt_text"]:
f.write(t + "\n")
f.flush() # needed to ensure gen_vocab sees dumped text
gen_vocab(Path(f.name), out_root / spm_filename_prefix, "char")
# Generate speaker list
speakers = sorted({sample["speaker"] for sample in samples})
speakers_path = out_root / "speakers.txt"
with open(speakers_path, "w") as f:
for speaker in speakers:
f.write(f"{speaker}\n")
# Generate config YAML
win_len_t = args.win_length / args.sample_rate
hop_len_t = args.hop_length / args.sample_rate
extra = {
"sample_rate": args.sample_rate,
"features": {
"type": "spectrogram+melscale+log",
"eps": 1e-2, "n_mels": args.n_mels, "n_fft": args.n_fft,
"window_fn": "hann", "win_length": args.win_length,
"hop_length": args.hop_length, "sample_rate": args.sample_rate,
"win_len_t": win_len_t, "hop_len_t": hop_len_t,
"f_min": args.f_min, "f_max": args.f_max,
"n_stft": args.n_fft // 2 + 1
}
}
if len(speakers) > 1:
extra["speaker_set_filename"] = "speakers.txt"
gen_config_yaml(
out_root, spm_filename=spm_filename, vocab_name=vocab_name,
audio_root=out_root.as_posix(), input_channels=None,
input_feat_per_channel=None, specaugment_policy=None,
cmvn_type="global", gcmvn_path=gcmvn_npz_path, extra=extra
)
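# A small illustrative sketch (an assumption, not part of the original script) of the
# core feature-extraction call inside process(): load one audio file, resample it,
# and write an 80-bin log-Mel spectrogram to out_path. The file paths are hypothetical.
def _single_file_feature_sketch(audio_path="sample.wav", out_path="sample.npy"):
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform, sample_rate = convert_waveform(
        waveform, sample_rate, normalize_volume=False, to_sample_rate=22050
    )
    extract_logmel_spectrogram(
        waveform, sample_rate, Path(out_path), win_length=1024, hop_length=256,
        n_fft=1024, n_mels=80, f_min=20, f_max=8000
    )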
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--audio-manifest-root", "-m", required=True, type=str)
parser.add_argument("--output-root", "-o", required=True, type=str)
parser.add_argument("--splits", "-s", type=str, nargs="+",
default=["train", "dev", "test"])
parser.add_argument("--ipa-vocab", action="store_true")
parser.add_argument("--use-g2p", action="store_true")
parser.add_argument("--lang", type=str, default="en-us")
parser.add_argument("--win-length", type=int, default=1024)
parser.add_argument("--hop-length", type=int, default=256)
parser.add_argument("--n-fft", type=int, default=1024)
parser.add_argument("--n-mels", type=int, default=80)
parser.add_argument("--f-min", type=int, default=20)
parser.add_argument("--f-max", type=int, default=8000)
parser.add_argument("--sample-rate", type=int, default=22050)
parser.add_argument("--normalize-volume", "-n", action="store_true")
parser.add_argument("--textgrid-zip", type=str, default=None)
parser.add_argument("--id-to-units-tsv", type=str, default=None)
parser.add_argument("--add-fastspeech-targets", action="store_true")
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
from pathlib import Path
from collections import defaultdict
import pandas as pd
from torchaudio.datasets import LJSPEECH
from tqdm import tqdm
from examples.speech_to_text.data_utils import save_df_to_tsv
log = logging.getLogger(__name__)
SPLITS = ["train", "dev", "test"]
def process(args):
out_root = Path(args.output_data_root).absolute()
out_root.mkdir(parents=True, exist_ok=True)
# Generate TSV manifest
print("Generating manifest...")
# following FastSpeech's splits
dataset = LJSPEECH(out_root.as_posix(), download=True)
id_to_split = {}
for x in dataset._flist:
id_ = x[0]
speaker = id_.split("-")[0]
id_to_split[id_] = {
"LJ001": "test", "LJ002": "test", "LJ003": "dev"
}.get(speaker, "train")
manifest_by_split = {split: defaultdict(list) for split in SPLITS}
progress = tqdm(enumerate(dataset), total=len(dataset))
for i, (waveform, _, utt, normalized_utt) in progress:
sample_id = dataset._flist[i][0]
split = id_to_split[sample_id]
manifest_by_split[split]["id"].append(sample_id)
audio_path = f"{dataset._path}/{sample_id}.wav"
manifest_by_split[split]["audio"].append(audio_path)
manifest_by_split[split]["n_frames"].append(len(waveform[0]))
manifest_by_split[split]["tgt_text"].append(normalized_utt)
manifest_by_split[split]["speaker"].append("ljspeech")
manifest_by_split[split]["src_text"].append(utt)
manifest_root = Path(args.output_manifest_root).absolute()
manifest_root.mkdir(parents=True, exist_ok=True)
for split in SPLITS:
save_df_to_tsv(
pd.DataFrame.from_dict(manifest_by_split[split]),
manifest_root / f"{split}.audio.tsv"
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--output-data-root", "-d", required=True, type=str)
parser.add_argument("--output-manifest-root", "-m", required=True, type=str)
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()