Commit 18d27e00 authored by wangwei990215's avatar wangwei990215
Browse files

initial commit

parent 541f4c7a
import os.path as op
from typing import BinaryIO, Optional, Tuple, Union
import numpy as np
def get_waveform(
path_or_fp: Union[str, BinaryIO], normalization=True
) -> Tuple[np.ndarray, int]:
"""Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC.
Args:
path_or_fp (str or BinaryIO): the path or file-like object
normalization (bool): Normalize values to [-1, 1] (Default: True)
"""
if isinstance(path_or_fp, str):
ext = op.splitext(op.basename(path_or_fp))[1]
if ext not in {".flac", ".wav"}:
raise ValueError(f"Unsupported audio format: {ext}")
try:
import soundfile as sf
except ImportError:
raise ImportError("Please install soundfile to load WAV/FLAC file")
waveform, sample_rate = sf.read(path_or_fp, dtype="float32")
if not normalization:
waveform *= 2 ** 15 # denormalized to 16-bit signed integers
return waveform, sample_rate
def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]:
"""Get mel-filter bank features via PyKaldi."""
try:
from kaldi.feat.mel import MelBanksOptions
from kaldi.feat.fbank import FbankOptions, Fbank
from kaldi.feat.window import FrameExtractionOptions
from kaldi.matrix import Vector
mel_opts = MelBanksOptions()
mel_opts.num_bins = n_bins
frame_opts = FrameExtractionOptions()
frame_opts.samp_freq = sample_rate
opts = FbankOptions()
opts.mel_opts = mel_opts
opts.frame_opts = frame_opts
fbank = Fbank(opts=opts)
features = fbank.compute(Vector(waveform), 1.0).numpy()
return features
except ImportError:
return None
def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]:
"""Get mel-filter bank features via TorchAudio."""
try:
import torch
import torchaudio.compliance.kaldi as ta_kaldi
waveform = torch.from_numpy(waveform).unsqueeze(0)
features = ta_kaldi.fbank(
waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
)
return features.numpy()
except ImportError:
return None
def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray:
"""Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
(faster CPP implementation) to TorchAudio (Python implementation). Note that
Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
waveform should not be normalized."""
sound, sample_rate = get_waveform(path_or_fp, normalization=False)
features = _get_kaldi_fbank(sound, sample_rate, n_bins)
if features is None:
features = _get_torchaudio_fbank(sound, sample_rate, n_bins)
if features is None:
raise ImportError(
"Please install pyKaldi or torchaudio to enable "
"online filterbank feature extraction"
)
return features
import importlib
import os
from abc import ABC, abstractmethod
from typing import Dict, Optional
class AudioFeatureTransform(ABC):
@classmethod
@abstractmethod
def from_config_dict(cls, config: Optional[Dict] = None):
pass
AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
def register_audio_feature_transform(name):
def register_audio_feature_transform_cls(cls):
if name in AUDIO_FEATURE_TRANSFORM_REGISTRY:
raise ValueError(f"Cannot register duplicate transform ({name})")
if not issubclass(cls, AudioFeatureTransform):
raise ValueError(
f"Transform ({name}: {cls.__name__}) must extend "
"AudioFeatureTransform"
)
if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES:
raise ValueError(
f"Cannot register audio feature transform with duplicate "
f"class name ({cls.__name__})"
)
AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__)
return cls
return register_audio_feature_transform_cls
def get_audio_feature_transform(name):
return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
transforms_dir = os.path.dirname(__file__)
for file in os.listdir(transforms_dir):
path = os.path.join(transforms_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
name = file[: file.find(".py")] if file.endswith(".py") else file
importlib.import_module("fairseq.data.audio.feature_transforms." + name)
class CompositeAudioFeatureTransform(AudioFeatureTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
_transforms = _config.get("transforms")
if _transforms is None:
return None
transforms = [
get_audio_feature_transform(_t).from_config_dict(_config.get(_t))
for _t in _transforms
]
return CompositeAudioFeatureTransform(transforms)
def __init__(self, transforms):
self.transforms = [t for t in transforms if t is not None]
def __call__(self, x):
for t in self.transforms:
x = t(x)
return x
def __repr__(self):
format_string = (
[self.__class__.__name__ + "("]
+ [f" {t.__repr__()}" for t in self.transforms]
+ [")"]
)
return "\n".join(format_string)
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("global_cmvn")
class GlobalCMVN(AudioFeatureTransform):
"""Global CMVN (cepstral mean and variance normalization). The global mean
and variance need to be pre-computed and stored in NumPy format (.npz)."""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return GlobalCMVN(_config.get("stats_npz_path"))
def __init__(self, stats_npz_path):
stats = np.load(stats_npz_path)
self.mean, self.std = stats["mean"], stats["std"]
def __call__(self, x):
x = np.subtract(x, self.mean)
x = np.divide(x, self.std)
return x
import math
import numbers
from typing import Optional
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("specaugment")
class SpecAugmentTransform(AudioFeatureTransform):
"""SpecAugment (https://arxiv.org/abs/1904.08779)"""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return SpecAugmentTransform(
_config.get("time_warp_W", 0),
_config.get("freq_mask_N", 0),
_config.get("freq_mask_F", 0),
_config.get("time_mask_N", 0),
_config.get("time_mask_T", 0),
_config.get("time_mask_p", 0.0),
_config.get("mask_value", None),
)
def __init__(
self,
time_warp_w: int = 0,
freq_mask_n: int = 0,
freq_mask_f: int = 0,
time_mask_n: int = 0,
time_mask_t: int = 0,
time_mask_p: float = 0.0,
mask_value: Optional[float] = 0.0,
):
# Sanity checks
assert mask_value is None or isinstance(
mask_value, numbers.Number
), f"mask_value (type: {type(mask_value)}) must be None or a number"
if freq_mask_n > 0:
assert freq_mask_f > 0, (
f"freq_mask_F ({freq_mask_f}) "
f"must be larger than 0 when doing freq masking."
)
if time_mask_n > 0:
assert time_mask_t > 0, (
f"time_mask_T ({time_mask_t}) must be larger than 0 when "
f"doing time masking."
)
self.time_warp_w = time_warp_w
self.freq_mask_n = freq_mask_n
self.freq_mask_f = freq_mask_f
self.time_mask_n = time_mask_n
self.time_mask_t = time_mask_t
self.time_mask_p = time_mask_p
self.mask_value = mask_value
def __repr__(self):
return (
self.__class__.__name__
+ "("
+ ", ".join(
[
f"time_warp_w={self.time_warp_w}",
f"freq_mask_n={self.freq_mask_n}",
f"freq_mask_f={self.freq_mask_f}",
f"time_mask_n={self.time_mask_n}",
f"time_mask_t={self.time_mask_t}",
f"time_mask_p={self.time_mask_p}",
]
)
+ ")"
)
def __call__(self, spectrogram):
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
distorted = spectrogram.copy() # make a copy of input spectrogram.
num_frames = spectrogram.shape[0] # or 'tau' in the paper.
num_freqs = spectrogram.shape[1] # or 'miu' in the paper.
mask_value = self.mask_value
if mask_value is None: # if no value was specified, use local mean.
mask_value = spectrogram.mean()
if num_frames == 0:
return spectrogram
if num_freqs < self.freq_mask_f:
return spectrogram
if self.time_warp_w > 0:
if 2 * self.time_warp_w < num_frames:
import cv2
w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
w = np.random.randint(0, self.time_warp_w)
upper, lower = distorted[:w0, :], distorted[w0:, :]
upper = cv2.resize(
upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
)
lower = cv2.resize(
lower,
dsize=(num_freqs, num_frames - w0 - w),
interpolation=cv2.INTER_LINEAR,
)
distorted = np.concatenate((upper, lower), axis=0)
for _i in range(self.freq_mask_n):
f = np.random.randint(0, self.freq_mask_f)
f0 = np.random.randint(0, num_freqs - f)
if f != 0:
distorted[:, f0 : f0 + f] = mask_value
max_time_mask_t = min(
self.time_mask_t, math.floor(num_frames * self.time_mask_p)
)
if max_time_mask_t < 1:
return distorted
for _i in range(self.time_mask_n):
t = np.random.randint(0, max_time_mask_t)
t0 = np.random.randint(0, num_frames - t)
if t != 0:
distorted[t0 : t0 + t, :] = mask_value
return distorted
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("utterance_cmvn")
class UtteranceCMVN(AudioFeatureTransform):
"""Utterance-level CMVN (cepstral mean and variance normalization)"""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return UtteranceCMVN(
_config.get("norm_means", True),
_config.get("norm_vars", True),
)
def __init__(self, norm_means=True, norm_vars=True):
self.norm_means, self.norm_vars = norm_means, norm_vars
def __repr__(self):
return (
self.__class__.__name__
+ f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
)
def __call__(self, x):
mean = x.mean(axis=0)
square_sums = (x ** 2).sum(axis=0)
if self.norm_means:
x = np.subtract(x, mean)
if self.norm_vars:
var = square_sums / x.shape[0] - mean ** 2
std = np.sqrt(np.maximum(var, 1e-10))
x = np.divide(x, std)
return x
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import sys
import numpy as np
import torch
import torch.nn.functional as F
from .. import FairseqDataset
logger = logging.getLogger(__name__)
class RawAudioDataset(FairseqDataset):
def __init__(
self,
sample_rate,
max_sample_size=None,
min_sample_size=None,
shuffle=True,
min_length=0,
pad=False,
normalize=False,
):
super().__init__()
self.sample_rate = sample_rate
self.sizes = []
self.max_sample_size = (
max_sample_size if max_sample_size is not None else sys.maxsize
)
self.min_sample_size = min_sample_size
self.min_length = min_length
self.pad = pad
self.shuffle = shuffle
self.normalize = normalize
def __getitem__(self, index):
raise NotImplementedError()
def __len__(self):
return len(self.sizes)
def postprocess(self, feats, curr_sample_rate):
if feats.dim() == 2:
feats = feats.mean(-1)
if curr_sample_rate != self.sample_rate:
raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")
assert feats.dim() == 1, feats.dim()
if self.normalize:
with torch.no_grad():
feats = F.layer_norm(feats, feats.shape)
return feats
def crop_to_max_size(self, wav, target_size):
size = len(wav)
diff = size - target_size
if diff <= 0:
return wav
start = np.random.randint(0, diff + 1)
end = size - diff + start
return wav[start:end]
def collater(self, samples):
samples = [s for s in samples if s["source"] is not None]
if len(samples) == 0:
return {}
sources = [s["source"] for s in samples]
sizes = [len(s) for s in sources]
if self.pad:
target_size = min(max(sizes), self.max_sample_size)
else:
target_size = min(min(sizes), self.max_sample_size)
collated_sources = sources[0].new_zeros(len(sources), target_size)
padding_mask = (
torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
)
for i, (source, size) in enumerate(zip(sources, sizes)):
diff = size - target_size
if diff == 0:
collated_sources[i] = source
elif diff < 0:
assert self.pad
collated_sources[i] = torch.cat(
[source, source.new_full((-diff,), 0.0)]
)
padding_mask[i, diff:] = True
else:
collated_sources[i] = self.crop_to_max_size(source, target_size)
input = {"source": collated_sources}
if self.pad:
input["padding_mask"] = padding_mask
return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input}
def num_tokens(self, index):
return self.size(index)
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
if self.pad:
return self.sizes[index]
return min(self.sizes[index], self.max_sample_size)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
order.append(self.sizes)
return np.lexsort(order)[::-1]
class FileAudioDataset(RawAudioDataset):
def __init__(
self,
manifest_path,
sample_rate,
max_sample_size=None,
min_sample_size=None,
shuffle=True,
min_length=0,
pad=False,
normalize=False,
):
super().__init__(
sample_rate=sample_rate,
max_sample_size=max_sample_size,
min_sample_size=min_sample_size,
shuffle=shuffle,
min_length=min_length,
pad=pad,
normalize=normalize,
)
self.fnames = []
skipped = 0
with open(manifest_path, "r") as f:
self.root_dir = f.readline().strip()
for line in f:
items = line.strip().split("\t")
assert len(items) == 2, line
sz = int(items[1])
if min_length is not None and sz < min_length:
skipped += 1
continue
self.fnames.append(items[0])
self.sizes.append(sz)
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
def __getitem__(self, index):
import soundfile as sf
fname = os.path.join(self.root_dir, self.fnames[index])
wav, curr_sample_rate = sf.read(fname)
feats = torch.from_numpy(wav).float()
feats = self.postprocess(feats, curr_sample_rate)
return {"id": index, "source": feats}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment