# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from collections import defaultdict
from itertools import chain
from pathlib import Path
import numpy as np
import torchaudio
import torchaudio.sox_effects as ta_sox
import yaml
from tqdm import tqdm
from examples.speech_to_text.data_utils import load_tsv_to_dicts
from examples.speech_synthesis.preprocessing.speaker_embedder import SpkrEmbedder
def extract_embedding(audio_path, embedder):
wav, sr = torchaudio.load(audio_path) # 2D
if sr != embedder.RATE:
wav, sr = ta_sox.apply_effects_tensor(
wav, sr, [["rate", str(embedder.RATE)]]
)
try:
emb = embedder([wav[0].cuda().float()]).cpu().numpy()
except RuntimeError:
emb = None
return emb
def process(args):
print("Fetching data...")
raw_manifest_root = Path(args.raw_manifest_root).absolute()
samples = [load_tsv_to_dicts(raw_manifest_root / (s + ".tsv"))
for s in args.splits]
samples = list(chain(*samples))
with open(args.config, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
with open(f"{config['audio_root']}/{config['speaker_set_filename']}") as f:
speaker_to_id = {r.strip(): i for i, r in enumerate(f)}
embedder = SpkrEmbedder(args.ckpt).cuda()
speaker_to_cnt = defaultdict(float)
speaker_to_emb = defaultdict(float)
for sample in tqdm(samples, desc="extract emb"):
emb = extract_embedding(sample["audio"], embedder)
if emb is not None:
speaker_to_cnt[sample["speaker"]] += 1
speaker_to_emb[sample["speaker"]] += emb
if len(speaker_to_emb) != len(speaker_to_id):
missed = set(speaker_to_id) - set(speaker_to_emb.keys())
print(
            f"WARNING: missing embeddings for {len(missed)} speaker(s):\n{missed}"
)
speaker_emb_mat = np.zeros((len(speaker_to_id), len(emb)), float)
for speaker in speaker_to_emb:
idx = speaker_to_id[speaker]
emb = speaker_to_emb[speaker]
cnt = speaker_to_cnt[speaker]
speaker_emb_mat[idx, :] = emb / cnt
speaker_emb_name = "speaker_emb.npy"
speaker_emb_path = f"{config['audio_root']}/{speaker_emb_name}"
np.save(speaker_emb_path, speaker_emb_mat)
config["speaker_emb_filename"] = speaker_emb_name
with open(args.new_config, "w") as f:
yaml.dump(config, f)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--raw-manifest-root", "-m", required=True, type=str)
parser.add_argument("--splits", "-s", type=str, nargs="+",
default=["train"])
parser.add_argument("--config", "-c", required=True, type=str)
parser.add_argument("--new-config", "-n", required=True, type=str)
parser.add_argument("--ckpt", required=True, type=str,
help="speaker embedder checkpoint")
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
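# Example invocation (a sketch: every path below is a placeholder, and the
# script file name is assumed rather than recorded in this commit):
#
#   python get_speaker_embedding.py \
#       --raw-manifest-root /path/to/manifests \
#       --splits train dev \
#       --config /path/to/config.yaml \
#       --new-config /path/to/config_with_spk_emb.yaml \
#       --ckpt /path/to/speaker_embedder.pt
#
# The per-speaker averaged embeddings are saved as speaker_emb.npy under
# config["audio_root"], and the new config records that file name under
# "speaker_emb_filename".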
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import numpy as np
import re
from pathlib import Path
from collections import defaultdict
import pandas as pd
from torchaudio.datasets import VCTK
from tqdm import tqdm
from examples.speech_to_text.data_utils import save_df_to_tsv
log = logging.getLogger(__name__)
SPLITS = ["train", "dev", "test"]
def normalize_text(text):
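    # Keep letters, the punctuation .?!,'- and spaces; strip everything else,
    # e.g. normalize_text("Hi there! (applause) 123") -> "Hi there! applause "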
return re.sub(r"[^a-zA-Z.?!,'\- ]", '', text)
def process(args):
out_root = Path(args.output_data_root).absolute()
out_root.mkdir(parents=True, exist_ok=True)
# Generate TSV manifest
print("Generating manifest...")
dataset = VCTK(out_root.as_posix(), download=False)
ids = list(dataset._walker)
np.random.seed(args.seed)
np.random.shuffle(ids)
n_train = len(ids) - args.n_dev - args.n_test
_split = ["train"] * n_train + ["dev"] * args.n_dev + ["test"] * args.n_test
id_to_split = dict(zip(ids, _split))
manifest_by_split = {split: defaultdict(list) for split in SPLITS}
progress = tqdm(enumerate(dataset), total=len(dataset))
for i, (waveform, _, text, speaker_id, _) in progress:
sample_id = dataset._walker[i]
_split = id_to_split[sample_id]
audio_dir = Path(dataset._path) / dataset._folder_audio / speaker_id
audio_path = audio_dir / f"{sample_id}.wav"
text = normalize_text(text)
manifest_by_split[_split]["id"].append(sample_id)
manifest_by_split[_split]["audio"].append(audio_path.as_posix())
manifest_by_split[_split]["n_frames"].append(len(waveform[0]))
manifest_by_split[_split]["tgt_text"].append(text)
manifest_by_split[_split]["speaker"].append(speaker_id)
manifest_by_split[_split]["src_text"].append(text)
manifest_root = Path(args.output_manifest_root).absolute()
manifest_root.mkdir(parents=True, exist_ok=True)
for _split in SPLITS:
save_df_to_tsv(
pd.DataFrame.from_dict(manifest_by_split[_split]),
manifest_root / f"{_split}.audio.tsv"
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--output-data-root", "-d", required=True, type=str)
parser.add_argument("--output-manifest-root", "-m", required=True, type=str)
parser.add_argument("--n-dev", default=50, type=int)
parser.add_argument("--n-test", default=100, type=int)
parser.add_argument("--seed", "-s", default=1234, type=int)
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
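# Example invocation (a sketch: the directories are placeholders and the
# script file name is assumed). VCTK is constructed with download=False, so
# the corpus must already be present under --output-data-root in the layout
# expected by torchaudio.datasets.VCTK:
#
#   python get_vctk_audio_manifest.py \
#       --output-data-root /path/to/VCTK \
#       --output-manifest-root /path/to/manifests \
#       --n-dev 50 --n-test 100 --seed 1234
#
# This writes train.audio.tsv, dev.audio.tsv and test.audio.tsv with the
# columns id, audio, n_frames, tgt_text, speaker and src_text.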
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchaudio
EMBEDDER_PARAMS = {
'num_mels': 40,
'n_fft': 512,
'emb_dim': 256,
'lstm_hidden': 768,
'lstm_layers': 3,
'window': 80,
'stride': 40,
}
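# "window" and "stride" are measured in mel frames: SpeechEmbedder.forward
# unfolds the (num_mels, T) input into sliding windows of 80 frames with a
# hop of 40, embeds each window, and averages the unit-normalized d-vectors.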
def set_requires_grad(nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary
computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
class LinearNorm(nn.Module):
def __init__(self, hp):
super(LinearNorm, self).__init__()
self.linear_layer = nn.Linear(hp["lstm_hidden"], hp["emb_dim"])
def forward(self, x):
return self.linear_layer(x)
class SpeechEmbedder(nn.Module):
def __init__(self, hp):
super(SpeechEmbedder, self).__init__()
self.lstm = nn.LSTM(hp["num_mels"],
hp["lstm_hidden"],
num_layers=hp["lstm_layers"],
batch_first=True)
self.proj = LinearNorm(hp)
self.hp = hp
def forward(self, mel):
# (num_mels, T) -> (num_mels, T', window)
mels = mel.unfold(1, self.hp["window"], self.hp["stride"])
mels = mels.permute(1, 2, 0) # (T', window, num_mels)
x, _ = self.lstm(mels) # (T', window, lstm_hidden)
x = x[:, -1, :] # (T', lstm_hidden), use last frame only
x = self.proj(x) # (T', emb_dim)
x = x / torch.norm(x, p=2, dim=1, keepdim=True) # (T', emb_dim)
x = x.mean(dim=0)
if x.norm(p=2) != 0:
x = x / x.norm(p=2)
return x
class SpkrEmbedder(nn.Module):
RATE = 16000
def __init__(
self,
embedder_path,
embedder_params=EMBEDDER_PARAMS,
rate=16000,
hop_length=160,
win_length=400,
pad=False,
):
super(SpkrEmbedder, self).__init__()
embedder_pt = torch.load(embedder_path, map_location="cpu")
self.embedder = SpeechEmbedder(embedder_params)
self.embedder.load_state_dict(embedder_pt)
self.embedder.eval()
set_requires_grad(self.embedder, requires_grad=False)
self.embedder_params = embedder_params
self.register_buffer('mel_basis', torch.from_numpy(
librosa.filters.mel(
sr=self.RATE,
n_fft=self.embedder_params["n_fft"],
n_mels=self.embedder_params["num_mels"])
)
)
self.resample = None
if rate != self.RATE:
self.resample = torchaudio.transforms.Resample(rate, self.RATE)
self.hop_length = hop_length
self.win_length = win_length
self.pad = pad
def get_mel(self, y):
if self.pad and y.shape[-1] < 14000:
y = F.pad(y, (0, 14000 - y.shape[-1]))
window = torch.hann_window(self.win_length).to(y)
y = torch.stft(y, n_fft=self.embedder_params["n_fft"],
hop_length=self.hop_length,
win_length=self.win_length,
window=window)
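        # torch.stft is used here in its legacy real-valued mode (no
        # return_complex), which returns (..., n_fft // 2 + 1, frames, 2) with
        # real/imaginary parts in the last dimension; newer PyTorch releases
        # may require return_complex to be passed explicitly.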
magnitudes = torch.norm(y, dim=-1, p=2) ** 2
mel = torch.log10(self.mel_basis @ magnitudes + 1e-6)
return mel
def forward(self, inputs):
dvecs = []
for wav in inputs:
mel = self.get_mel(wav)
if mel.dim() == 3:
mel = mel.squeeze(0)
dvecs += [self.embedder(mel)]
dvecs = torch.stack(dvecs)
dvec = torch.mean(dvecs, dim=0)
dvec = dvec / torch.norm(dvec)
return dvec
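# Minimal usage sketch (the checkpoint and wav paths are hypothetical; the
# checkpoint must match the SpeechEmbedder architecture above and the audio
# is assumed to already be 16 kHz, matching the default rate):
#
#   import torchaudio
#   embedder = SpkrEmbedder("speaker_embedder.pt").cuda()
#   wav, sr = torchaudio.load("sample_16k.wav")   # (channels, T)
#   dvec = embedder([wav[0].cuda().float()])      # unit-norm tensor of size emb_dim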
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import collections
import contextlib
import wave
try:
import webrtcvad
except ImportError:
raise ImportError("Please install py-webrtcvad: pip install webrtcvad")
import argparse
import os
import logging
from tqdm import tqdm
AUDIO_SUFFIX = '.wav'
FS_MS = 30
SCALE = 6e-5
THRESHOLD = 0.3
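# FS_MS is the VAD frame size in milliseconds; THRESHOLD is the longest
# silence gap (in seconds) kept between voiced segments; SCALE converts a
# duration in seconds into a count of zero padding bytes in main() below.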
def read_wave(path):
"""Reads a .wav file.
Takes the path, and returns (PCM audio data, sample rate).
"""
with contextlib.closing(wave.open(path, 'rb')) as wf:
num_channels = wf.getnchannels()
assert num_channels == 1
sample_width = wf.getsampwidth()
assert sample_width == 2
sample_rate = wf.getframerate()
assert sample_rate in (8000, 16000, 32000, 48000)
pcm_data = wf.readframes(wf.getnframes())
return pcm_data, sample_rate
def write_wave(path, audio, sample_rate):
"""Writes a .wav file.
Takes path, PCM audio data, and sample rate.
"""
with contextlib.closing(wave.open(path, 'wb')) as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(sample_rate)
wf.writeframes(audio)
class Frame(object):
"""Represents a "frame" of audio data."""
def __init__(self, bytes, timestamp, duration):
self.bytes = bytes
self.timestamp = timestamp
self.duration = duration
def frame_generator(frame_duration_ms, audio, sample_rate):
"""Generates audio frames from PCM audio data.
Takes the desired frame duration in milliseconds, the PCM data, and
the sample rate.
Yields Frames of the requested duration.
"""
n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
offset = 0
timestamp = 0.0
duration = (float(n) / sample_rate) / 2.0
while offset + n < len(audio):
yield Frame(audio[offset:offset + n], timestamp, duration)
timestamp += duration
offset += n
def vad_collector(sample_rate, frame_duration_ms,
padding_duration_ms, vad, frames):
"""Filters out non-voiced audio frames.
Given a webrtcvad.Vad and a source of audio frames, yields only
the voiced audio.
Uses a padded, sliding window algorithm over the audio frames.
When more than 90% of the frames in the window are voiced (as
reported by the VAD), the collector triggers and begins yielding
audio frames. Then the collector waits until 90% of the frames in
the window are unvoiced to detrigger.
The window is padded at the front and back to provide a small
amount of silence or the beginnings/endings of speech around the
voiced frames.
Arguments:
sample_rate - The audio sample rate, in Hz.
frame_duration_ms - The frame duration in milliseconds.
padding_duration_ms - The amount to pad the window, in milliseconds.
vad - An instance of webrtcvad.Vad.
frames - a source of audio frames (sequence or generator).
Returns: A generator that yields PCM audio data.
"""
num_padding_frames = int(padding_duration_ms / frame_duration_ms)
# We use a deque for our sliding window/ring buffer.
ring_buffer = collections.deque(maxlen=num_padding_frames)
# We have two states: TRIGGERED and NOTTRIGGERED. We start in the
# NOTTRIGGERED state.
triggered = False
voiced_frames = []
for frame in frames:
is_speech = vad.is_speech(frame.bytes, sample_rate)
# sys.stdout.write('1' if is_speech else '0')
if not triggered:
ring_buffer.append((frame, is_speech))
num_voiced = len([f for f, speech in ring_buffer if speech])
# If we're NOTTRIGGERED and more than 90% of the frames in
# the ring buffer are voiced frames, then enter the
# TRIGGERED state.
if num_voiced > 0.9 * ring_buffer.maxlen:
triggered = True
# We want to yield all the audio we see from now until
# we are NOTTRIGGERED, but we have to start with the
# audio that's already in the ring buffer.
for f, _ in ring_buffer:
voiced_frames.append(f)
ring_buffer.clear()
else:
# We're in the TRIGGERED state, so collect the audio data
# and add it to the ring buffer.
voiced_frames.append(frame)
ring_buffer.append((frame, is_speech))
num_unvoiced = len([f for f, speech in ring_buffer if not speech])
# If more than 90% of the frames in the ring buffer are
# unvoiced, then enter NOTTRIGGERED and yield whatever
# audio we've collected.
if num_unvoiced > 0.9 * ring_buffer.maxlen:
triggered = False
yield [b''.join([f.bytes for f in voiced_frames]),
voiced_frames[0].timestamp, voiced_frames[-1].timestamp]
ring_buffer.clear()
voiced_frames = []
# If we have any leftover voiced audio when we run out of input,
# yield it.
if voiced_frames:
yield [b''.join([f.bytes for f in voiced_frames]),
voiced_frames[0].timestamp, voiced_frames[-1].timestamp]
def main(args):
# create output folder
try:
cmd = f"mkdir -p {args.out_path}"
os.system(cmd)
except Exception:
logging.error("Can not create output folder")
exit(-1)
# build vad object
vad = webrtcvad.Vad(int(args.agg))
# iterating over wavs in dir
for file in tqdm(os.listdir(args.in_path)):
if file.endswith(AUDIO_SUFFIX):
audio_inpath = os.path.join(args.in_path, file)
audio_outpath = os.path.join(args.out_path, file)
audio, sample_rate = read_wave(audio_inpath)
frames = frame_generator(FS_MS, audio, sample_rate)
frames = list(frames)
segments = vad_collector(sample_rate, FS_MS, 300, vad, frames)
merge_segments = list()
timestamp_start = 0.0
timestamp_end = 0.0
            # remove leading/trailing silence and shorten long silence gaps
for i, segment in enumerate(segments):
merge_segments.append(segment[0])
if i and timestamp_start:
sil_duration = segment[1] - timestamp_end
if sil_duration > THRESHOLD:
merge_segments.append(int(THRESHOLD / SCALE)*(b'\x00'))
else:
merge_segments.append(int((sil_duration / SCALE))*(b'\x00'))
timestamp_start = segment[1]
timestamp_end = segment[2]
segment = b''.join(merge_segments)
write_wave(audio_outpath, segment, sample_rate)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Apply VAD to a folder of audio files.')
parser.add_argument('in_path', type=str, help='Path to the input files')
parser.add_argument('out_path', type=str,
help='Path to save the processed files')
parser.add_argument('--agg', type=int, default=3,
help='The level of aggressiveness of the VAD: [0-3]')
args = parser.parse_args()
main(args)
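# Example invocation (a sketch: the script and directory names are placeholders):
#
#   python apply_vad.py /path/to/raw_wavs /path/to/vad_wavs --agg 3
#
# Every *.wav file in in_path is run through webrtcvad; leading/trailing
# silence is dropped, long silence gaps between voiced segments are replaced
# by a short fixed pad, and the result is written under out_path with the
# same file name.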
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from scipy.interpolate import interp1d
import torchaudio
from fairseq.tasks.text_to_speech import (
batch_compute_distortion, compute_rms_dist
)
def batch_mel_spectral_distortion(
y1, y2, sr, normalize_type="path", mel_fn=None
):
"""
https://arxiv.org/pdf/2011.03568.pdf
Same as Mel Cepstral Distortion, but computed on log-mel spectrograms.
"""
if mel_fn is None or mel_fn.sample_rate != sr:
mel_fn = torchaudio.transforms.MelSpectrogram(
sr, n_fft=int(0.05 * sr), win_length=int(0.05 * sr),
hop_length=int(0.0125 * sr), f_min=20, n_mels=80,
window_fn=torch.hann_window
).to(y1[0].device)
offset = 1e-6
return batch_compute_distortion(
y1, y2, sr, lambda y: torch.log(mel_fn(y) + offset).transpose(-1, -2),
compute_rms_dist, normalize_type
)
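# Toy usage sketch (it assumes, as the indexing of y1[0] above suggests, that
# y1 and y2 are lists of 1-D waveform tensors on the same device):
#
#   y_ref = [torch.randn(16000)]
#   y_syn = [torch.randn(16000)]
#   batch_mel_spectral_distortion(y_ref, y_syn, sr=16000)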
# This code is based on
# "https://github.com/bastibe/MAPS-Scripts/blob/master/helper.py"
def _same_t_in_true_and_est(func):
def new_func(true_t, true_f, est_t, est_f):
assert type(true_t) is np.ndarray
assert type(true_f) is np.ndarray
assert type(est_t) is np.ndarray
assert type(est_f) is np.ndarray
interpolated_f = interp1d(
est_t, est_f, bounds_error=False, kind='nearest', fill_value=0
)(true_t)
return func(true_t, true_f, true_t, interpolated_f)
return new_func
@_same_t_in_true_and_est
def gross_pitch_error(true_t, true_f, est_t, est_f):
"""The relative frequency in percent of pitch estimates that are
outside a threshold around the true pitch. Only frames that are
considered pitched by both the ground truth and the estimator (if
applicable) are considered.
"""
correct_frames = _true_voiced_frames(true_t, true_f, est_t, est_f)
gross_pitch_error_frames = _gross_pitch_error_frames(
true_t, true_f, est_t, est_f
)
return np.sum(gross_pitch_error_frames) / np.sum(correct_frames)
def _gross_pitch_error_frames(true_t, true_f, est_t, est_f, eps=1e-8):
voiced_frames = _true_voiced_frames(true_t, true_f, est_t, est_f)
true_f_p_eps = [x + eps for x in true_f]
pitch_error_frames = np.abs(est_f / true_f_p_eps - 1) > 0.2
return voiced_frames & pitch_error_frames
def _true_voiced_frames(true_t, true_f, est_t, est_f):
return (est_f != 0) & (true_f != 0)
def _voicing_decision_error_frames(true_t, true_f, est_t, est_f):
return (est_f != 0) != (true_f != 0)
@_same_t_in_true_and_est
def f0_frame_error(true_t, true_f, est_t, est_f):
gross_pitch_error_frames = _gross_pitch_error_frames(
true_t, true_f, est_t, est_f
)
voicing_decision_error_frames = _voicing_decision_error_frames(
true_t, true_f, est_t, est_f
)
return (np.sum(gross_pitch_error_frames) +
np.sum(voicing_decision_error_frames)) / (len(true_t))
@_same_t_in_true_and_est
def voicing_decision_error(true_t, true_f, est_t, est_f):
voicing_decision_error_frames = _voicing_decision_error_frames(
true_t, true_f, est_t, est_f
)
return np.sum(voicing_decision_error_frames) / (len(true_t))
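# Toy example with synthetic pitch tracks (values are made up; only the input
# format matters: time and F0 arrays of equal length, with F0 == 0 marking
# unvoiced frames):
#
#   t = np.linspace(0, 1, 100)
#   true_f = np.where(t < 0.5, 200.0, 0.0)       # voiced in the first half
#   est_f = np.where(t < 0.5, 205.0, 0.0)        # within the 20% GPE band
#   gross_pitch_error(t, true_f, t, est_f)       # -> 0.0
#   voicing_decision_error(t, true_f, t, est_f)  # -> 0.0
#   f0_frame_error(t, true_f, t, est_f)          # -> 0.0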
# Joint Speech Text training in Fairseq
An extension of the fairseq S2T (speech-to-text) project in which the speech-to-text task is enhanced by a co-trained text-to-text mapping task. More details about fairseq S2T can be found [here](../speech_to_text/README.md).
## Examples
Examples of joint speech-text training in fairseq:
- [English-to-German MuST-C model](docs/ende-mustc.md)
- [IWSLT 2021 Multilingual Speech Translation](docs/iwslt2021.md)
## Citation
Please cite as:
```
@inproceedings{Tang2021AGM,
title={A General Multi-Task Learning Framework to Leverage Text Data for Speech to Text Tasks},
author={Yun Tang and J. Pino and Changhan Wang and Xutai Ma and Dmitriy Genzel},
booktitle={ICASSP},
year={2021}
}
@inproceedings{Tang2021IST,
title = {Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task},
author = {Yun Tang and Juan Pino and Xian Li and Changhan Wang and Dmitriy Genzel},
booktitle = {ACL},
year = {2021},
}
@inproceedings{Tang2021FST,
title = {FST: the FAIR Speech Translation System for the IWSLT21 Multilingual Shared Task},
author = {Yun Tang and Hongyu Gong and Xian Li and Changhan Wang and Juan Pino and Holger Schwenk and Naman Goyal},
booktitle = {IWSLT},
year = {2021},
}
@inproceedings{wang2020fairseqs2t,
title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},
author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino},
booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations},
year = {2020},
}
@inproceedings{ott2019fairseq,
title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
year = {2019},
}
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import tasks, criterions, models # noqa
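# Importing these sub-packages runs their fairseq registration decorators
# (@register_task / @register_criterion / @register_model), so passing this
# example directory to --user-dir is enough to make the joint speech-text
# components visible to the fairseq command-line tools.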
"(Applause) NOISE
"(Laughter) VOICE
"(Laughter)" VOICE
(Applause) NOISE
(Applause). NOISE
(Audience) VOICE
(Audio) NOISE
(Beat) NOISE
(Beatboxing) VOICE
(Beep) NOISE
(Beeps) NOISE
(Cheering) VOICE
(Cheers) VOICE
(Claps) NOISE
(Clicking) NOISE
(Clunk) NOISE
(Coughs) NOISE
(Drums) NOISE
(Explosion) NOISE
(Gasps) VOICE
(Guitar) NOISE
(Honk) NOISE
(Laugher) VOICE
(Laughing) VOICE
(Laughs) VOICE
(Laughter) VOICE
(Laughter). VOICE
(Laughter)... VOICE
(Mumbling) VOICE
(Music) NOISE
(Noise) NOISE
(Recording) VOICE
(Ringing) NOISE
(Shouts) VOICE
(Sigh) VOICE
(Sighs) VOICE
(Silence) NOISE
(Singing) VOICE
(Sings) VOICE
(Spanish) VOICE
(Static) NOISE
(Tones) NOISE
(Trumpet) NOISE
(Video) NOISE
(Video): NOISE
(Voice-over) NOISE
(Whistle) NOISE
(Whistling) NOISE
(video): NOISE